From 807290c9f9979c897d89261eb62bc9698dc9e9d0 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Wed, 12 Apr 2023 21:25:02 -0500
Subject: [PATCH] In XSD, handle frac digits for decimal types. Also try to
 support custom type declarations in the XSD
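
xs:decimal restrictions can carry a totalDigits facet in addition to
fractionDigits; read it for the precision instead of always assuming 38.
Also resolve named simple types through the schema's type table, so a
custom decimal type no longer falls through to the default mapping.

As a rough illustration (the type and element names below are invented
for this message, not taken from the test resources), a declaration like

    <xs:simpleType name="myDecimal">
        <xs:restriction base="xs:decimal">
            <xs:totalDigits value="12"/>
            <xs:fractionDigits value="6"/>
        </xs:restriction>
    </xs:simpleType>
    <xs:element name="amount" type="myDecimal"/>

should now map the "amount" element to DecimalType(12, 6), whether the
restriction is written inline under the element or, as above, declared
as a named type and referenced by it.
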
---
 .../spark/xml/util/XSDToSchema.scala          | 25 +++++++++++++------
 .../resources/decimal-with-restriction.xsd    |  7 ++++++
 .../spark/xml/util/XSDToSchemaSuite.scala     |  9 ++++---
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala b/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala
index 00e131ed..8638fe66 100644
--- a/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala
+++ b/src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala
@@ -85,16 +85,25 @@ object XSDToSchema {
               case null => restriction.getBaseTypeName
               case n => n
             }
-            qName match {
+
+            // Hacky, is there a better way? see if the type is known as a custom
+            // type and use that if so, assuming along the way it's a simple restriction
+            val typeName = xmlSchema.getSchemaTypes.asScala.get(qName).map { s =>
+              s.asInstanceOf[XmlSchemaSimpleType].
+                getContent.asInstanceOf[XmlSchemaSimpleTypeRestriction].getBaseTypeName
+            }.getOrElse(qName)
+
+            typeName match {
               case Constants.XSD_BOOLEAN => BooleanType
               case Constants.XSD_DECIMAL =>
-                val scale = restriction.getFacets.asScala.collectFirst {
-                  case facet: XmlSchemaFractionDigitsFacet => facet
-                }
-                scale match {
-                  case Some(scale) => DecimalType(38, scale.getValue.toString.toInt)
-                  case None => DecimalType(38, 18)
-                }
+                val facets = restriction.getFacets.asScala
+                val fracDigits = facets.collectFirst {
+                  case facet: XmlSchemaFractionDigitsFacet => facet.getValue.toString.toInt
+                }.getOrElse(18)
+                val totalDigits = facets.collectFirst {
+                  case facet: XmlSchemaTotalDigitsFacet => facet.getValue.toString.toInt
+                }.getOrElse(38)
+                DecimalType(totalDigits, math.min(totalDigits, fracDigits))
               case Constants.XSD_UNSIGNEDLONG => DecimalType(38, 0)
               case Constants.XSD_DOUBLE => DoubleType
               case Constants.XSD_FLOAT => FloatType
diff --git a/src/test/resources/decimal-with-restriction.xsd b/src/test/resources/decimal-with-restriction.xsd
index 9fd4eb8f..e60cc548 100644
--- a/src/test/resources/decimal-with-restriction.xsd
+++ b/src/test/resources/decimal-with-restriction.xsd
@@ -1,5 +1,11 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
+    <xs:simpleType name="custom_decimal">
+        <xs:restriction base="xs:decimal">
+            <xs:totalDigits value="12"/>
+            <xs:fractionDigits value="6"/>
+        </xs:restriction>
+    </xs:simpleType>
     <xs:element name="decimal_type_1" type="xs:decimal"/>
     <xs:element name="decimal_type_2">
         <xs:simpleType>
@@ -8,4 +14,5 @@
             </xs:restriction>
         </xs:simpleType>
     </xs:element>
+    <xs:element name="decimal_type_3" type="custom_decimal"/>
 </xs:schema>
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala
index 42da366f..4bf534b4 100644
--- a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala
+++ b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala
@@ -115,13 +115,16 @@ class XSDToSchemaSuite extends AnyFunSuite {
     val expectedSchema = buildSchema(
       field("test",
         struct(field("userId", LongType, nullable = false)),
        nullable = false))
-    assert(expectedSchema === parsedSchema)
+    assert(parsedSchema === expectedSchema)
   }
 
   test("Test xs:decimal type with restriction[fractionalDigits]") {
     val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/decimal-with-restriction.xsd"))
-    val expectedSchema = buildSchema(field("decimal_type_1", DecimalType(38, 18), nullable = false),
-      field("decimal_type_2", DecimalType(38, 2), nullable = false))
+    val expectedSchema = buildSchema(
+      field("decimal_type_3", DecimalType(12, 6), nullable = false),
+      field("decimal_type_1", DecimalType(38, 18), nullable = false),
+      field("decimal_type_2", DecimalType(38, 2), nullable = false)
+    )
     assert(parsedSchema === expectedSchema)
   }