Adding attribute resolution for MapType
AndreSchumacher committed Jun 19, 2014
1 parent b539fde commit 824500c
Showing 4 changed files with 45 additions and 30 deletions.
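
This commit extends Catalyst attribute resolution so that MapType columns can be addressed with the same bracket syntax already supported for array ordinals: data1[key1] now resolves to a GetItem lookup typed by the map's key type. A minimal usage sketch of what the change enables, against the Parquet fixture used in the test suite below (TestSQLContext, parquetFile, registerAsTable, and the table and column names all come from the tests; assembling them this way is illustrative, not part of the commit):

    import org.apache.spark.sql.test.TestSQLContext
    import TestSQLContext._

    // Load the nested test file written by ParquetTestData and register it as a table.
    val data = TestSQLContext
      .parquetFile(ParquetTestData.testNestedDir4.toString)
      .toSchemaRDD
    data.registerAsTable("mapTable")

    // With this change a map value can be pulled out by key directly in SQL:
    val rows = sql("SELECT data1[key1] FROM mapTable").collect()
    // For the fixture data, rows(0)(0) == 1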
@@ -20,8 +20,16 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, StructType}
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
+import scala.util.matching.Regex
+import org.apache.spark.sql.catalyst.types.ArrayType
+import org.apache.spark.sql.catalyst.expressions.GetField
+import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.catalyst.types.MapType
+import scala.Some
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.GetItem
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   self: Product =>
@@ -60,20 +68,32 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   def resolve(name: String): Option[NamedExpression] = {
     def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
       val (exp, t) = expType
-      val ordinalRegExp = """(\[(\d+)\])""".r
-      val fieldName = if (field.matches("\\w*(\\[\\d\\])+")) {
+      val ordinalRegExp = """(\[(\d+|\w+)\])""".r
+      val fieldName = if (ordinalRegExp.findFirstIn(field).isDefined) {
         field.substring(0, field.indexOf("["))
       } else {
         field
       }
       t match {
         case ArrayType(elementType) =>
-          val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
+          val ordinals = ordinalRegExp
+            .findAllIn(field)
+            .matchData
+            .map(_.group(2))
           (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
             GetItem(v1, Literal(v2.toInt))), elementType)
+        case MapType(keyType, valueType) =>
+          val ordinals = ordinalRegExp
+            .findAllIn(field)
+            .matchData
+            .map(_.group(2))
+          // TODO: we should recover the JVM type of valueType to match the
+          // actual type of the key?! should we restrict ourselves to NativeType?
+          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
+            GetItem(v1, Literal(v2, keyType))), valueType)
         case StructType(fields) =>
           // Note: this only works if we are not on the top-level!
-          val structField = fields.find(_.name == fieldName)
+          val structField = fields
+            .find(_.name == fieldName)
           if (!structField.isDefined) {
             throw new TreeNodeException(
               this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
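
The widened regex (\[(\d+|\w+)\]) above accepts both numeric array ordinals and word-character map keys; group(2) is the text between the brackets, and expandFunc decides how to interpret it from the column's DataType, not from the regex. A standalone sketch of the extraction (plain Scala, no Catalyst required; the helper name is made up):

    val ordinalRegExp = """(\[(\d+|\w+)\])""".r

    // Collect the bracketed segments of a field reference, innermost group only.
    def bracketKeys(field: String): List[String] =
      ordinalRegExp.findAllIn(field).matchData.map(_.group(2)).toList

    bracketKeys("data[0][1]")   // List("0", "1")  -> ArrayType: Literal(v.toInt)
    bracketKeys("data1[key1]")  // List("key1")    -> MapType: Literal(v, keyType)
    bracketKeys("plain")        // List()          -> no brackets, plain field name

Note that \w also matches digits, so a numeric-looking key such as data2[7] still reaches the MapType branch when the column is a map; the lookup literal is then built as a string key rather than an Int ordinal.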
@@ -106,7 +126,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         // TODO from rebase!
         /*val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
         val relevantRemaining =
-          if (remainingParts.head.matches("\\w*\\[(\\d+)\\]")) { // array field name
+          if (remainingParts.head.matches("\\w*\\[(\\d+|\\w+)\\]")) { // array field name
             remainingParts.head.substring(0, remainingParts.head.indexOf("["))
           } else {
             remainingParts.head
@@ -117,7 +137,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     options.distinct match {
       case (a, Nil) :: Nil => {
         a.dataType match {
-          case ArrayType(elementType) =>
+          case ArrayType(_) | MapType(_, _) =>
             val expression = expandFunc((a: Expression, a.dataType), name)._1
             Some(Alias(expression, name)())
           case _ => Some(a)
@@ -130,7 +150,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         // this is compatibility reasons with earlier code!
         // TODO: why only nestedFields and not parts?
         // check for absence of nested arrays so there are only fields
-        if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[\\d+\\]+"))) {
+        if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[(\\d+|\\w+)\\]+"))) {
           Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
         } else {
           val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
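
The fold in the last line above threads an (Expression, DataType) pair through expandFunc one path segment at a time, so each bracket or struct field wraps the expression built so far. A hedged sketch of the tree this produces for a reference like data2[7].payload1 (the attribute and its types are illustrative stand-ins; GetItem, GetField, Literal, and Alias are the Catalyst expressions used in this file):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical attribute standing in for a resolved relation output.
    val data2 = AttributeReference("data2",
      MapType(StringType, StructType(Nil)))()

    // parts = List("data2[7]", "payload1") folds, roughly, into:
    val resolved = Alias(
      GetField(
        GetItem(data2, Literal("7", StringType)), // MapType branch of expandFunc
        "payload1"),                              // StructType branch
      "data2[7].payload1")()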
@@ -455,8 +455,7 @@ private[parquet] class CatalystMapConverter(
 
   // TODO: think about reusing the buffer
   override def end(): Unit = {
-    assert(!isRootConverter)
-    parent.updateField(index, map)
+    parent.updateField(index, map.toMap)
   }
 
   override def getConverter(fieldIndex: Int): Converter = keyValueConverter
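
The converter change above hands updateField an immutable snapshot instead of the converter's mutable buffer; given the TODO about reusing the buffer, the .toMap copy keeps later mutations from leaking into rows that were already produced. A plain-Scala sketch of the difference:

    import scala.collection.mutable

    val buffer = mutable.HashMap("key1" -> 1)
    val snapshot = buffer.toMap // immutable copy, as now passed to updateField
    buffer("key1") = 99         // the converter may reuse the buffer later

    assert(snapshot("key1") == 1) // the earlier snapshot is unaffected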
@@ -179,7 +179,7 @@ private[sql] object ParquetTestData {
       |}
       |required group data2 {
       |repeated group map {
-      |required int32 key;
+      |required binary key;
       |optional group value {
       |required int64 payload1;
       |optional binary payload2;
@@ -366,12 +366,14 @@ private[sql] object ParquetTestData {
     keyValue2.add(1, 2)
     val map2 = r1.addGroup(2)
     val keyValue3 = map2.addGroup(0)
-    keyValue3.add(0, 7)
+    // TODO: currently only string key type supported
+    keyValue3.add(0, "7")
     val valueGroup1 = keyValue3.addGroup(1)
     valueGroup1.add(0, 42.toLong)
     valueGroup1.add(1, "the answer")
     val keyValue4 = map2.addGroup(0)
-    keyValue4.add(0, 8)
+    // TODO: currently only string key type supported
+    keyValue4.add(0, "8")
     val valueGroup2 = keyValue4.addGroup(1)
     valueGroup2.add(0, 49.toLong)
 
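
With the schema's key column switched from int32 to binary, the keys written above are strings. For the single record in this file, the data2 column read back by the SQL layer is, roughly, the following value (values shown as tuples for brevity; they are actually Rows with fields payload1 and payload2):

    Map(
      "7" -> (42L, "the answer"), // payload1, payload2
      "8" -> (49L, null)          // payload2 was left unset (optional field)
    )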
@@ -518,7 +518,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   }
 
   test("simple map") {
-    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[String, Int]]
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Int]]
     val data = TestSQLContext
       .parquetFile(ParquetTestData.testNestedDir4.toString)
       .toSchemaRDD
@@ -527,36 +527,30 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     assert(result1.size === 1)
     assert(result1(0)(0).toMap.getOrElse("key1", 0) === 1)
     assert(result1(0)(0).toMap.getOrElse("key2", 0) === 2)
+    val result2 = sql("SELECT data1[key1] FROM mapTable").collect()
+    assert(result2(0)(0) === 1)
   }
 
   test("map with struct values") {
-    //implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
-    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[Int, Row]]
-    //val data = TestSQLContext
-    //  .parquetFile(ParquetTestData.testNestedDir4.toString)
-    //  .toSchemaRDD
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Row]]
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
     data.registerAsTable("mapTable")
 
-    /*ParquetTestData.readNestedFile(
-      ParquetTestData.testNestedDir4,
-      ParquetTestData.testNestedSchema4)
-    val result = TestSQLContext
-      .parquetFile(ParquetTestData.testNestedDir4.toString)
-      .toSchemaRDD
-      .collect()*/
     val result1 = sql("SELECT data2 FROM mapTable").collect()
     assert(result1.size === 1)
-    val entry1 = result1(0)(0).getOrElse(7, null)
+    val entry1 = result1(0)(0).getOrElse("7", null)
     assert(entry1 != null)
     assert(entry1(0) === 42)
     assert(entry1(1) === "the answer")
-    val entry2 = result1(0)(0).getOrElse(8, null)
+    val entry2 = result1(0)(0).getOrElse("8", null)
     assert(entry2 != null)
     assert(entry2(0) === 49)
     assert(entry2(1) === null)
+    val result2 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM mapTable").collect()
+    assert(result2.size === 1)
+    assert(result2(0)(0) === 42.toLong)
+    assert(result2(0)(1) === "the answer")
   }
 
   /**
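
Two details of the tests above are worth spelling out. First, the queries write data2[7] with a bare 7 even though the file's keys are strings: the MapType branch of expandFunc builds the lookup literal from the matched text and the column's key type, i.e. Literal("7", StringType), rather than parsing an Int. Second, the implicit anyToMap only hides a cast, since Row fields are typed Any; written explicitly (a sketch using the result1 collected above):

    val asMap = result1(0)(0).asInstanceOf[Map[String, Row]] // what anyToMap does
    assert(asMap("7")(0) === 42L)          // payload1
    assert(asMap("7")(1) === "the answer") // payload2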
