diff --git a/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala b/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala index 00e131ed..64f76eba 100644 --- a/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala +++ b/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala @@ -138,74 +138,30 @@ object XSDToSchema { case unsupported => throw new IllegalArgumentException(s"Unsupported content: $unsupported") } + case content: XmlSchemaComplexContent => + val complexContent = content.getContent + complexContent match { + case extension: XmlSchemaComplexContentExtension => + val baseStructField = getStructField(xmlSchema, + xmlSchema.getParent.getTypeByQName(extension.getBaseTypeName)) + val baseFields = baseStructField.dataType match { + case structType: StructType => structType.fields + case others => + throw new IllegalArgumentException( + s"Non-StructType in ComplexContentExtension: $others" + ) + } + + val extendedFields = getStructFieldsFromParticle(extension.getParticle, xmlSchema) + StructField( + schemaType.getQName.getLocalPart, + StructType(baseFields ++ extendedFields) + ) + case unsupported => + throw new IllegalArgumentException(s"Unsupported content: $unsupported") + } case null => - val childFields = - complexType.getParticle match { - // xs:all - case all: XmlSchemaAll => - all.getItems.asScala.map { - case element: XmlSchemaElement => - val baseStructField = getStructField(xmlSchema, element.getSchemaType) - val nullable = element.getMinOccurs == 0 - if (element.getMaxOccurs == 1) { - StructField(element.getName, baseStructField.dataType, nullable) - } else { - StructField(element.getName, ArrayType(baseStructField.dataType), nullable) - } - }.toSeq - // xs:choice - case choice: XmlSchemaChoice => - choice.getItems.asScala.map { - case element: XmlSchemaElement => - val baseStructField = getStructField(xmlSchema, element.getSchemaType) - if (element.getMaxOccurs == 1) { - StructField(element.getName, baseStructField.dataType, true) - } else { - StructField(element.getName, ArrayType(baseStructField.dataType), true) - } - case any: XmlSchemaAny => - val dataType = if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType - StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, true) - }.toSeq - // xs:sequence - case sequence: XmlSchemaSequence => - // flatten xs:choice nodes - sequence.getItems.asScala.flatMap { _ match { - case choice: XmlSchemaChoice => - choice.getItems.asScala.map { e => - val xme = e.asInstanceOf[XmlSchemaElement] - val baseType = getStructField(xmlSchema, xme.getSchemaType).dataType - val dataType = if (xme.getMaxOccurs > 1) ArrayType(baseType) else baseType - StructField(xme.getName, dataType, true) - } - case e: XmlSchemaElement => - val refQName = e.getRef.getTargetQName - val baseType = - if (refQName != null) { - getStructField( - xmlSchema, - xmlSchema.getParent.getElementByQName(refQName).getSchemaType).dataType - } - else getStructField(xmlSchema, e.getSchemaType).dataType - val dataType = if (e.getMaxOccurs > 1) ArrayType(baseType) else baseType - val nullable = e.getMinOccurs == 0 - val structFieldName = - Option(refQName).map(_.getLocalPart).getOrElse(e.getName) - Seq(StructField(structFieldName, dataType, nullable)) - case any: XmlSchemaAny => - val dataType = - if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType - val nullable = any.getMinOccurs == 0 - Seq(StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, nullable)) - case unsupported => - throw new IllegalArgumentException(s"Unsupported item: $unsupported") - } - }.toSeq - case null => - Seq.empty - case unsupported => - throw new IllegalArgumentException(s"Unsupported particle: $unsupported") - } + val childFields = getStructFieldsFromParticle(complexType.getParticle, xmlSchema) val attributes = complexType.getAttributes.asScala.map { case attribute: XmlSchemaAttribute => val attributeType = attribute.getSchemaTypeName match { @@ -237,4 +193,76 @@ object XSDToSchema { }) } + private def getStructFieldsFromParticle( + particle: XmlSchemaParticle, + xmlSchema: XmlSchema + ): Seq[StructField] = { + particle match { + // xs:all + case all: XmlSchemaAll => + all.getItems.asScala.map { + case element: XmlSchemaElement => + val baseStructField = getStructField(xmlSchema, element.getSchemaType) + val nullable = element.getMinOccurs == 0 + if (element.getMaxOccurs == 1) { + StructField(element.getName, baseStructField.dataType, nullable) + } else { + StructField(element.getName, ArrayType(baseStructField.dataType), nullable) + } + }.toSeq + // xs:choice + case choice: XmlSchemaChoice => + choice.getItems.asScala.map { + case element: XmlSchemaElement => + val baseStructField = getStructField(xmlSchema, element.getSchemaType) + if (element.getMaxOccurs == 1) { + StructField(element.getName, baseStructField.dataType, true) + } else { + StructField(element.getName, ArrayType(baseStructField.dataType), true) + } + case any: XmlSchemaAny => + val dataType = if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType + StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, true) + }.toSeq + // xs:sequence + case sequence: XmlSchemaSequence => + // flatten xs:choice nodes + sequence.getItems.asScala.flatMap { + _ match { + case choice: XmlSchemaChoice => + choice.getItems.asScala.map { e => + val xme = e.asInstanceOf[XmlSchemaElement] + val baseType = getStructField(xmlSchema, xme.getSchemaType).dataType + val dataType = if (xme.getMaxOccurs > 1) ArrayType(baseType) else baseType + StructField(xme.getName, dataType, true) + } + case e: XmlSchemaElement => + val refQName = e.getRef.getTargetQName + val baseType = + if (refQName != null) { + getStructField( + xmlSchema, + xmlSchema.getParent.getElementByQName(refQName).getSchemaType).dataType + } + else getStructField(xmlSchema, e.getSchemaType).dataType + val dataType = if (e.getMaxOccurs > 1) ArrayType(baseType) else baseType + val nullable = e.getMinOccurs == 0 + val structFieldName = + Option(refQName).map(_.getLocalPart).getOrElse(e.getName) + Seq(StructField(structFieldName, dataType, nullable)) + case any: XmlSchemaAny => + val dataType = + if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType + val nullable = any.getMinOccurs == 0 + Seq(StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, nullable)) + case unsupported => + throw new IllegalArgumentException(s"Unsupported item: $unsupported") + } + }.toSeq + case null => + Seq.empty + case unsupported => + throw new IllegalArgumentException(s"Unsupported particle: $unsupported") + } + } } diff --git a/src/test/resources/complex-content-extension.xsd b/src/test/resources/complex-content-extension.xsd new file mode 100644 index 00000000..f1371930 --- /dev/null +++ b/src/test/resources/complex-content-extension.xsd @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala index 42da366f..8f27099a 100644 --- a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala @@ -152,4 +152,23 @@ class XSDToSchemaSuite extends AnyFunSuite { ) assert(parsedSchema === expectedSchema) } + + test("Test complex content with extension element / Issue 554") { + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/complex-content-extension.xsd")) + + val expectedSchema = buildSchema( + field( + "employee", + struct( + field("firstname", StringType, false), + field("lastname", StringType, false), + field("address", StringType, false), + field("city", StringType, false), + field("country", StringType, false) + ), + false + ) + ) + assert(parsedSchema === expectedSchema) + } }