Skip to content

Commit

Permalink
Add schema_of_xml_df for Python future use (#438)
Browse files Browse the repository at this point in the history
See #435.
  • Loading branch information
srowen authored Feb 14, 2020
1 parent a71f735 commit 27dce95
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
11 changes: 11 additions & 0 deletions src/main/scala/com/databricks/spark/xml/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ package object xml {
def schema_of_xml(ds: Dataset[String], options: Map[String, String] = Map.empty): StructType = {
  // Inference works on the Dataset's underlying RDD of XML strings.
  val xmlDocs = ds.rdd
  // Wrap the raw key/value option map in XmlOptions before handing it to the inferrer.
  val xmlOptions = XmlOptions(options)
  InferSchema.infer(xmlDocs, xmlOptions)
}

/**
 * Infers the schema of XML documents held as strings in a DataFrame.
 *
 * The DataFrame is converted to a `Dataset[String]` using `Encoders.STRING` and then
 * passed, together with `options`, to `schema_of_xml`.
 *
 * @param df one-column DataFrame of XML strings
 * @param options additional XML parsing options
 * @return inferred schema for XML
 */
@Experimental
def schema_of_xml_df(df: DataFrame, options: Map[String, String] = Map.empty): StructType = {
  // Convert the untyped rows to strings with an explicit encoder, so no implicits are needed.
  val xmlStrings: Dataset[String] = df.as[String](Encoders.STRING)
  schema_of_xml(xmlStrings, options)
}

/**
* Infers the schema of XML documents when inputs are arrays of strings, each an XML doc.
*
Expand Down
8 changes: 4 additions & 4 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
""".stripMargin
import spark.implicits._
val df = spark.createDataFrame(Seq((8, xmlData))).toDF("number", "payload")
val xmlSchema = schema_of_xml(df.select("payload").as[String])
val xmlSchema = schema_of_xml_df(df.select("payload"))
val expectedSchema = df.schema.add("decoded", xmlSchema)
val result = df.withColumn("decoded", from_xml(df.col("payload"), xmlSchema))

Expand Down Expand Up @@ -1179,7 +1179,7 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
""".stripMargin
import spark.implicits._
val df = spark.createDataFrame(Seq((8, xmlData))).toDF("number", "payload")
val xmlSchema = schema_of_xml(df.select("payload").as[String])
val xmlSchema = schema_of_xml_df(df.select("payload"))
val result = df.withColumn("decoded", from_xml(df.col("payload"), xmlSchema))
assert(result.select("decoded._corrupt_record").head().getString(0).nonEmpty)
}
Expand All @@ -1192,7 +1192,7 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
""".stripMargin
import spark.implicits._
val df = spark.createDataFrame(Seq((8, xmlData))).toDF("number", "payload")
val xmlSchema = schema_of_xml(df.select("payload").as[String])
val xmlSchema = schema_of_xml_df(df.select("payload"))
val result = from_xml_string(xmlData, xmlSchema)

assert(result.getString(0) === "bar")
Expand All @@ -1214,7 +1214,7 @@ final class XmlSuite extends FunSuite with BeforeAndAfterAll {
""".stripMargin
import spark.implicits._
val dfNoError = spark.createDataFrame(Seq((8, xmlDataNoError))).toDF("number", "payload")
val xmlSchema = schema_of_xml(dfNoError.select("payload").as[String])
val xmlSchema = schema_of_xml_df(dfNoError.select("payload"))
val df = spark.createDataFrame(Seq((8, xmlData))).toDF("number", "payload")
val result = df.withColumn("decoded", from_xml(df.col("payload"), xmlSchema))
assert(result.select("decoded").head().get(0) === null)
Expand Down

0 comments on commit 27dce95

Please sign in to comment.