diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 66c9c674fa657..9ac1484686e72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -50,12 +50,14 @@ object CatalystConverter {
           case _ => new CatalystArrayConverter(elementType, fieldIndex, parent)
         }
       }
-      case StructType(fields: Seq[StructField]) =>
-        new CatalystGroupConverter(fields, fieldIndex, parent)
-      case ctype: NativeType =>
+      case StructType(fields: Seq[StructField]) => {
+        new CatalystStructConverter(fields, fieldIndex, parent)
+      }
+      case ctype: NativeType => {
         // note: for some reason matching for StringType fails so use this ugly if instead
         if (ctype == StringType) new CatalystPrimitiveStringConverter(parent, fieldIndex)
         else new CatalystPrimitiveConverter(parent, fieldIndex)
+      }
       case _ => throw new RuntimeException(
         s"unable to convert datatype ${field.dataType.toString} in CatalystGroupConverter")
     }
@@ -109,7 +111,7 @@ trait CatalystConverter {
  * @param schema The corresponding Catalyst schema in the form of a list of attributes.
  */
 class CatalystGroupConverter(
-    private[parquet] val schema: Seq[FieldType],
+    protected[parquet] val schema: Seq[FieldType],
     protected[parquet] val index: Int,
     protected[parquet] val parent: CatalystConverter,
     protected[parquet] var current: ArrayBuffer[Any],
@@ -277,6 +279,23 @@ class CatalystArrayConverter(
   }
 }
 
+// this is for multi-element groups of primitive or complex types
+// that have repetition level optional or required (so struct fields)
+class CatalystStructConverter(
+    override protected[parquet] val schema: Seq[FieldType],
+    override protected[parquet] val index: Int,
+    override protected[parquet] val parent: CatalystConverter)
+  extends CatalystGroupConverter(schema, index, parent) {
+
+  override protected[parquet] def clearBuffer(): Unit = {}
+
+  // TODO: think about reusing the buffer
+  override def end(): Unit = {
+    assert(!isRootConverter)
+    parent.updateField(index, current)
+  }
+}
+
 // TODO: add MapConverter
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index 981c403ef9cf8..142c429240a55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -263,13 +263,19 @@ private[sql] object ParquetTestData {
     val booleanNumberPairs = r1.addGroup(3)
     booleanNumberPairs.add("value", 2.5)
     booleanNumberPairs.add("truth", false)
-    r1.addGroup(4).addGroup(0).addGroup(0).add(
+    val top_level = r1.addGroup(4)
+    val second_level_a = top_level.addGroup(0)
+    val second_level_b = top_level.addGroup(0)
+    val third_level_aa = second_level_a.addGroup(0)
+    val third_level_ab = second_level_a.addGroup(0)
+    val third_level_c = second_level_b.addGroup(0)
+    third_level_aa.add(
       CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 7)
-    r1.addGroup(4).addGroup(0).addGroup(0).add(
+    third_level_ab.add(
       CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 8)
-    r1.addGroup(4).addGroup(0).addGroup(0).add(
+    third_level_c.add(
       CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 9)
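
Note on the ParquetConverter.scala change: the new CatalystStructConverter overrides end() so that a completed nested struct is handed to its parent converter as a single field value, instead of being treated like the root group's row buffer (which is kept and cleared between records). Below is a minimal, self-contained sketch of that contract; FieldSink, RootRowSink and StructSink are hypothetical stand-ins, not Spark's actual classes.

    import scala.collection.mutable.ArrayBuffer

    trait FieldSink {
      def updateField(fieldIndex: Int, value: Any): Unit
    }

    // stands in for the root group converter: it accumulates the row itself
    class RootRowSink extends FieldSink {
      val currentRow = new ArrayBuffer[Any]
      override def updateField(fieldIndex: Int, value: Any): Unit =
        currentRow += value
    }

    // stands in for CatalystStructConverter: buffers its own fields, then
    // forwards the finished buffer upward in end() rather than clearing it
    class StructSink(index: Int, parent: FieldSink) extends FieldSink {
      private val current = new ArrayBuffer[Any]
      override def updateField(fieldIndex: Int, value: Any): Unit =
        current += value
      // called when the reader closes the group: the completed struct
      // becomes one field value of the enclosing record
      def end(): Unit = parent.updateField(index, current)
    }

    object StructEndDemo extends App {
      val root = new RootRowSink
      root.updateField(0, "primitive sibling")  // an ordinary top-level field
      val struct = new StructSink(1, root)
      struct.updateField(0, 7)                  // fields inside the nested struct
      struct.updateField(1, 8)
      struct.end()
      println(root.currentRow)  // ArrayBuffer(primitive sibling, ArrayBuffer(7, 8))
    }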
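
Note on the ParquetTestData.scala change: each addGroup call on a Parquet example Group appends and returns a new child group, so the old chained r1.addGroup(4).addGroup(0).addGroup(0) calls built three disjoint single-leaf paths rather than one shared nested structure; naming and reusing the intermediate groups, as the patch does, builds a single tree. A toy sketch of just that append-and-return behavior (ToyGroup is hypothetical, not the real Group API):

    import scala.collection.mutable.ArrayBuffer

    class ToyGroup {
      val children = new ArrayBuffer[ToyGroup]
      // like addGroup in the test data: append a NEW child and return it
      def addGroup(fieldIndex: Int): ToyGroup = {
        val child = new ToyGroup
        children += child
        child
      }
    }

    object AddGroupDemo extends App {
      val before = new ToyGroup
      before.addGroup(4).addGroup(0).addGroup(0)  // each chain starts a fresh
      before.addGroup(4).addGroup(0).addGroup(0)  // group at field 4
      before.addGroup(4).addGroup(0).addGroup(0)
      println(before.children.size)               // 3 disjoint top-level groups

      val after = new ToyGroup                    // the patched style
      val topLevel = after.addGroup(4)
      val secondA = topLevel.addGroup(0)
      val secondB = topLevel.addGroup(0)
      secondA.addGroup(0)                         // third_level_aa
      secondA.addGroup(0)                         // third_level_ab
      secondB.addGroup(0)                         // third_level_c
      println(after.children.size)                // 1 shared top-level group
      println(topLevel.children.size)             // 2 second-level groups
    }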