Skip to content

Commit

Permalink
In XSD, handle frac digits for decimal types. Also try to support cus…
Browse files Browse the repository at this point in the history
…tom type declarations in the XSD (#638)
  • Loading branch information
srowen authored Apr 13, 2023
1 parent ce51382 commit 0fe32ac
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
25 changes: 17 additions & 8 deletions src/main/scala/com/databricks/spark/xml/util/XSDToSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,25 @@ object XSDToSchema {
case null => restriction.getBaseTypeName
case n => n
}
qName match {

// Hacky, is there a better way? see if the type is known as a custom
// type and use that if so, assuming along the way it's a simple restriction
val typeName = xmlSchema.getSchemaTypes.asScala.get(qName).map { s =>
s.asInstanceOf[XmlSchemaSimpleType].
getContent.asInstanceOf[XmlSchemaSimpleTypeRestriction].getBaseTypeName
}.getOrElse(qName)

typeName match {
case Constants.XSD_BOOLEAN => BooleanType
case Constants.XSD_DECIMAL =>
val scale = restriction.getFacets.asScala.collectFirst {
case facet: XmlSchemaFractionDigitsFacet => facet
}
scale match {
case Some(scale) => DecimalType(38, scale.getValue.toString.toInt)
case None => DecimalType(38, 18)
}
val facets = restriction.getFacets.asScala
val fracDigits = facets.collectFirst {
case facet: XmlSchemaFractionDigitsFacet => facet.getValue.toString.toInt
}.getOrElse(18)
val totalDigits = facets.collectFirst {
case facet: XmlSchemaTotalDigitsFacet => facet.getValue.toString.toInt
}.getOrElse(38)
DecimalType(totalDigits, math.min(totalDigits, fracDigits))
case Constants.XSD_UNSIGNEDLONG => DecimalType(38, 0)
case Constants.XSD_DOUBLE => DoubleType
case Constants.XSD_FLOAT => FloatType
Expand Down
7 changes: 7 additions & 0 deletions src/test/resources/decimal-with-restriction.xsd
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
<?xml version="1.0" encoding="UTF-8" ?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:simpleType name="somedecimal">
<xs:restriction base="xs:decimal">
<xs:fractionDigits value="6"/>
<xs:totalDigits value="12"/>
</xs:restriction>
</xs:simpleType>
<xs:element name="decimal_type_1" type="xs:decimal"/>
<xs:element name="decimal_type_2">
<xs:simpleType>
Expand All @@ -8,4 +14,5 @@
</xs:restriction>
</xs:simpleType>
</xs:element>
<xs:element name="decimal_type_3" type="somedecimal"/>
</xs:schema>
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ class XSDToSchemaSuite extends AnyFunSuite {
val expectedSchema = buildSchema(
field("test",
struct(field("userId", LongType, nullable = false)), nullable = false))
assert(expectedSchema === parsedSchema)
assert(parsedSchema === expectedSchema)
}

test("Test xs:decimal type with restriction[fractionalDigits]") {
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/decimal-with-restriction.xsd"))
val expectedSchema = buildSchema(field("decimal_type_1", DecimalType(38, 18), nullable = false),
field("decimal_type_2", DecimalType(38, 2), nullable = false))
val expectedSchema = buildSchema(
field("decimal_type_3", DecimalType(12, 6), nullable = false),
field("decimal_type_1", DecimalType(38, 18), nullable = false),
field("decimal_type_2", DecimalType(38, 2), nullable = false)
)
assert(parsedSchema === expectedSchema)
}

Expand Down

0 comments on commit 0fe32ac

Please sign in to comment.