diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index b499436a49c18..1f814c560fd64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, StructType}
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   self: Product =>
@@ -60,20 +60,32 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   def resolve(name: String): Option[NamedExpression] = {
     def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
       val (exp, t) = expType
-      val ordinalRegExp = """(\[(\d+)\])""".r
-      val fieldName = if (field.matches("\\w*(\\[\\d\\])+")) {
+      val ordinalRegExp = """(\[(\d+|\w+)\])""".r
+      val fieldName = if (ordinalRegExp.findFirstIn(field).isDefined) {
         field.substring(0, field.indexOf("["))
       } else {
         field
       }
       t match {
         case ArrayType(elementType) =>
-          val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
+          val ordinals = ordinalRegExp
+            .findAllIn(field)
+            .matchData
+            .map(_.group(2))
           (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
             GetItem(v1, Literal(v2.toInt))), elementType)
+        case MapType(keyType, valueType) =>
+          val ordinals = ordinalRegExp
+            .findAllIn(field)
+            .matchData
+            .map(_.group(2))
+          // TODO: the parser hands us string keys; we should convert them to
+          // keyType's JVM representation. Restrict map keys to NativeType?
+          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
+            GetItem(v1, Literal(v2, keyType))), valueType)
         case StructType(fields) =>
-          // Note: this only works if we are not on the top-level!
-          val structField = fields.find(_.name == fieldName)
+          val structField = fields
+            .find(_.name == fieldName)
           if (!structField.isDefined) {
             throw new TreeNodeException(
               this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
@@ -106,7 +118,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
       // TODO from rebase!
       /*val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
       val relevantRemaining =
-        if (remainingParts.head.matches("\\w*\\[(\\d+)\\]")) { // array field name
+        if (remainingParts.head.matches("\\w*\\[(\\d+|\\w+)\\]")) { // array or map field name
          remainingParts.head.substring(0, remainingParts.head.indexOf("["))
         } else {
           remainingParts.head
@@ -117,7 +129,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     options.distinct match {
       case (a, Nil) :: Nil => {
         a.dataType match {
-          case ArrayType(elementType) =>
+          case ArrayType(_) | MapType(_, _) =>
             val expression = expandFunc((a: Expression, a.dataType), name)._1
             Some(Alias(expression, name)())
           case _ => Some(a)
@@ -130,7 +142,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         // this is compatibility reasons with earlier code!
         // TODO: why only nestedFields and not parts?
         // check for absence of nested arrays so there are only fields
-        if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[\\d+\\]+"))) {
+        if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[(\\d+|\\w+)\\]+"))) {
           Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
         } else {
           val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
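Note on the resolution change above: `expandFunc` strips bracket suffixes such as `[7]` or `[key1]` off a field reference and folds the captured ordinals into nested `GetItem` expressions; for arrays each ordinal becomes an integer literal, while for maps it stays a string literal typed with the map's `keyType`. A minimal self-contained sketch of that fold, using hypothetical stand-ins (`Ref`, `Item`) rather than Catalyst's `Expression`/`GetItem`:

```scala
object ExpandFuncSketch {
  // Hypothetical stand-ins for Catalyst's Expression and GetItem nodes.
  sealed trait Expr
  case class Ref(name: String) extends Expr
  case class Item(child: Expr, key: Any) extends Expr

  // Same pattern as the patched ordinalRegExp: matches "[7]" as well as "[key1]".
  private val ordinalRegExp = """(\[(\d+|\w+)\])""".r

  /** Folds "data2[7][8]" into Item(Item(Ref("data2"), 7), 8). */
  def expand(field: String, intKeys: Boolean): Expr = {
    val base =
      if (ordinalRegExp.findFirstIn(field).isDefined) {
        field.substring(0, field.indexOf("["))
      } else {
        field
      }
    val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
    ordinals.foldLeft(Ref(base): Expr) { (expr, ordinal) =>
      // Arrays index with Int literals; map keys keep the raw string for now.
      Item(expr, if (intKeys) ordinal.toInt else ordinal)
    }
  }

  def main(args: Array[String]): Unit = {
    println(expand("data1[key1]", intKeys = false)) // Item(Ref(data1),key1)
    println(expand("data2[7][8]", intKeys = true))  // Item(Item(Ref(data2),7),8)
  }
}
```

One regex drives both cases, which is why `data1[key1]` and `data2[7]` resolve through a single code path.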
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 94a7ab719789f..37306b4a26078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -455,8 +455,7 @@ private[parquet] class CatalystMapConverter(
 
   // TODO: think about reusing the buffer
   override def end(): Unit = {
-    assert(!isRootConverter)
-    parent.updateField(index, map)
+    parent.updateField(index, map.toMap)
   }
 
   override def getConverter(fieldIndex: Int): Converter = keyValueConverter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index c0ae738418202..e3957732eb341 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -179,7 +179,7 @@ private[sql] object ParquetTestData {
      |}
      |required group data2 {
      |repeated group map {
-      |required int32 key;
+      |required binary key;
      |optional group value {
      |required int64 payload1;
      |optional binary payload2;
@@ -366,12 +366,14 @@ private[sql] object ParquetTestData {
     keyValue2.add(1, 2)
     val map2 = r1.addGroup(2)
     val keyValue3 = map2.addGroup(0)
-    keyValue3.add(0, 7)
+    // TODO: currently only string keys are supported
+    keyValue3.add(0, "7")
     val valueGroup1 = keyValue3.addGroup(1)
     valueGroup1.add(0, 42.toLong)
     valueGroup1.add(1, "the answer")
     val keyValue4 = map2.addGroup(0)
-    keyValue4.add(0, 8)
+    // TODO: currently only string keys are supported
+    keyValue4.add(0, "8")
     val valueGroup2 = keyValue4.addGroup(1)
     valueGroup2.add(0, 49.toLong)
 
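Note on the `CatalystMapConverter` change: `end()` now emits `map.toMap`, an immutable snapshot, instead of the mutable buffer itself. Given the `// TODO: think about reusing the buffer` right above it, handing out the buffer directly would alias every emitted row to a structure that later records may clear and refill. A small self-contained sketch of that hazard under the buffer-reuse assumption (names here are illustrative, not the Parquet converter API):

```scala
object SnapshotSketch {
  import scala.collection.mutable

  def main(args: Array[String]): Unit = {
    val buffer = mutable.HashMap[String, Int]()
    val emitted = mutable.ArrayBuffer[Map[String, Int]]()

    // Record 1: emit an immutable snapshot of the buffer.
    buffer += ("key1" -> 1)
    emitted += buffer.toMap // safe: later mutation cannot touch this value

    // Record 2: the buffer is cleared and reused.
    buffer.clear()
    buffer += ("key2" -> 2)
    emitted += buffer.toMap

    // Each row keeps its own contents; without .toMap the first emitted
    // value would have been aliased to the (now cleared) buffer.
    assert(emitted(0) == Map("key1" -> 1))
    assert(emitted(1) == Map("key2" -> 2))
    println(emitted)
  }
}
```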
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 4dcf6d472bd76..ea940184ca4e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -518,7 +518,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("simple map") {
-    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[String, Int]]
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Int]]
     val data = TestSQLContext
       .parquetFile(ParquetTestData.testNestedDir4.toString)
       .toSchemaRDD
@@ -527,36 +527,30 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     assert(result1.size === 1)
     assert(result1(0)(0).toMap.getOrElse("key1", 0) === 1)
     assert(result1(0)(0).toMap.getOrElse("key2", 0) === 2)
+    val result2 = sql("SELECT data1[key1] FROM mapTable").collect()
+    assert(result2(0)(0) === 1)
   }
 
   test("map with struct values") {
-    //implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
-    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[Int, Row]]
-    //val data = TestSQLContext
-    //  .parquetFile(ParquetTestData.testNestedDir4.toString)
-    //  .toSchemaRDD
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Row]]
     val data = TestSQLContext
       .parquetFile(ParquetTestData.testNestedDir4.toString)
       .toSchemaRDD
     data.registerAsTable("mapTable")
-
-    /*ParquetTestData.readNestedFile(
-      ParquetTestData.testNestedDir4,
-      ParquetTestData.testNestedSchema4)
-    val result = TestSQLContext
-      .parquetFile(ParquetTestData.testNestedDir4.toString)
-      .toSchemaRDD
-      .collect()*/
     val result1 = sql("SELECT data2 FROM mapTable").collect()
     assert(result1.size === 1)
-    val entry1 = result1(0)(0).getOrElse(7, null)
+    val entry1 = result1(0)(0).getOrElse("7", null)
     assert(entry1 != null)
     assert(entry1(0) === 42)
     assert(entry1(1) === "the answer")
-    val entry2 = result1(0)(0).getOrElse(8, null)
+    val entry2 = result1(0)(0).getOrElse("8", null)
     assert(entry2 != null)
     assert(entry2(0) === 49)
     assert(entry2(1) === null)
+    val result2 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM mapTable").collect()
+    assert(result2.size === 1)
+    assert(result2(0)(0) === 42.toLong)
+    assert(result2(0)(1) === "the answer")
   }
 
   /**
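For reference, the end-to-end usage the new tests pin down, written as plain queries. This is a sketch that assumes the suite's conventions (`import TestSQLContext._` to bring `sql` into scope, and access to the `private[sql]` ParquetTestData fixtures), so it only runs inside the `org.apache.spark.sql` test tree:

```scala
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.parquet.ParquetTestData

// Load the nested test file and register it, as both tests do.
val data = TestSQLContext
  .parquetFile(ParquetTestData.testNestedDir4.toString)
  .toSchemaRDD
data.registerAsTable("mapTable")

// Whole-map projection: rows now carry immutable Maps (see map.toMap above).
sql("SELECT data1 FROM mapTable").collect()

// Keyed access through the new expandFunc path; note the string key "7".
sql("SELECT data1[key1] FROM mapTable").collect()        // value 1
sql("SELECT data2[7].payload1 FROM mapTable").collect()  // value 42L
```

Under the hood, `data2[7].payload1` resolves to a `GetField` over a `GetItem` with a string-typed key literal, via the fold sketched after the LogicalPlan diff.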