From 191bc0d18e721251b1a79390533199794a56ea21 Mon Sep 17 00:00:00 2001
From: Andre Schumacher
Date: Sat, 24 May 2014 11:49:09 +0300
Subject: [PATCH] Changing to Seq for ArrayType, refactoring SQLParser for nested field extension

---
 .../apache/spark/sql/catalyst/SqlParser.scala | 112 +++++++++---------
 .../catalyst/expressions/complexTypes.scala   |   4 +-
 .../catalyst/plans/logical/LogicalPlan.scala  |  76 ++----------
 .../spark/sql/parquet/ParquetConverter.scala  |  13 +-
 .../spark/sql/parquet/ParquetTestData.scala   |   4 +-
 .../spark/sql/parquet/ParquetQuerySuite.scala |  97 ++++++++-----
 6 files changed, 144 insertions(+), 162 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d2baf09074799..2ad2d04af5704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected case class Keyword(str: String)
 
   protected implicit def asParser(k: Keyword): Parser[String] =
-    allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
-  protected class SqlLexical extends StdLexical {
-    case class FloatLit(chars: String) extends Token {
-      override def toString = chars
-    }
-    override lazy val token: Parser[Token] = (
-      identChar ~ rep( identChar | digit ) ^^
-        { case first ~ rest => processIdent(first :: rest mkString "") }
-      | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
-          case i ~ None => NumericLit(i mkString "")
-          case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
-        }
-      | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
-        { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
-      | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
-        { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
-      | EofCh ^^^ EOF
-      | '\'' ~> failure("unclosed string literal")
-      | '\"' ~> failure("unclosed string literal")
-      | delim
-      | failure("illegal character")
-    )
-
-    override def identChar = letter | elem('.') | elem('_') | elem('[') | elem(']')
-
-    override def whitespace: Parser[Any] = rep(
-      whitespaceChar
-      | '/' ~ '*' ~ comment
-      | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
-      | '#' ~ rep( chrExcept(EofCh, '\n') )
-      | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
-      | '/' ~ '*' ~ failure("unclosed comment")
-    )
-  }
-
-  override val lexical = new SqlLexical
+    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
 
   protected val ALL = Keyword("ALL")
   protected val AND = Keyword("AND")
@@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     this.getClass
       .getMethods
      .filter(_.getReturnType == classOf[Keyword])
-      .map(_.invoke(this).asInstanceOf[Keyword])
-
-  /** Generate all variations of upper and lower case of a given string */
-  private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
-    if (s == "") {
-      Stream(prefix)
-    } else {
-      allCaseVersions(s.tail, prefix + s.head.toLower) ++
-        allCaseVersions(s.tail, prefix + s.head.toUpper)
-    }
-  }
+      .map(_.invoke(this).asInstanceOf[Keyword].str)
 
-  lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str))
-
-  lexical.delimiters += (
-    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
-    ",", ";", "%", "{", "}", ":", "[", "]"
-  )
+  override val lexical = new SqlLexical(reservedWords)
 
   protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
     exprs.zipWithIndex.map {
@@ -383,14 +332,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
 
   protected lazy val baseExpression: PackratParser[Expression] =
-    expression ~ "[" ~ expression <~ "]" ^^ {
+    expression ~ "[" ~ expression <~ "]" ^^ {
       case base ~ _ ~ ordinal => GetItem(base, ordinal)
     } |
     TRUE ^^^ Literal(true, BooleanType) |
     FALSE ^^^ Literal(false, BooleanType) |
     cast |
     "(" ~> expression <~ ")" |
-    "[" ~> literal <~ "]" |
     function |
     "-" ~> literal ^^ UnaryMinus |
     ident ^^ UnresolvedAttribute |
@@ -400,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected lazy val dataType: Parser[DataType] =
     STRING ^^^ StringType
 }
+
+class SqlLexical(val keywords: Seq[String]) extends StdLexical {
+  case class FloatLit(chars: String) extends Token {
+    override def toString = chars
+  }
+
+  reserved ++= keywords.flatMap(w => allCaseVersions(w))
+
+  delimiters += (
+    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+    ",", ";", "%", "{", "}", ":", "[", "]"
+  )
+
+  override lazy val token: Parser[Token] = (
+    identChar ~ rep( identChar | digit ) ^^
+      { case first ~ rest => processIdent(first :: rest mkString "") }
+    | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
+        case i ~ None => NumericLit(i mkString "")
+        case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
+      }
+    | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
+      { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
+    | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
+      { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
+    | EofCh ^^^ EOF
+    | '\'' ~> failure("unclosed string literal")
+    | '\"' ~> failure("unclosed string literal")
+    | delim
+    | failure("illegal character")
+  )
+
+  override def identChar = letter | elem('_') | elem('.')
+
+  override def whitespace: Parser[Any] = rep(
+    whitespaceChar
+    | '/' ~ '*' ~ comment
+    | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
+    | '#' ~ rep( chrExcept(EofCh, '\n') )
+    | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
+    | '/' ~ '*' ~ failure("unclosed comment")
+  )
+
+  /** Generate all variations of upper and lower case of a given string */
+  def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+    if (s == "") {
+      Stream(prefix)
+    } else {
+      allCaseVersions(s.tail, prefix + s.head.toLower) ++
+        allCaseVersions(s.tail, prefix + s.head.toUpper)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 285118fc81ded..37ccb965feb87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -50,7 +50,9 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
       null
     } else {
       if (child.dataType.isInstanceOf[ArrayType]) {
-        val baseValue = value.asInstanceOf[Array[_]]
+        // TODO: consider using Array[_] for ArrayType child to avoid
+        // boxing of primitives
+        val baseValue = value.asInstanceOf[Seq[_]]
         val o = key.asInstanceOf[Int]
         if (o >= baseValue.size || o < 0) {
           null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7ff2287fbee03..76459f49cae02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -58,53 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
    * can contain ordinal expressions, such as `field[i][j][k]...`.
    */
   def resolve(name: String): Option[NamedExpression] = {
-    def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
-      val (exp, t) = expType
-      val ordinalRegExp = """(\[(\d+|\w+)\])""".r
-      val fieldName = if (ordinalRegExp.findFirstIn(field).isDefined) {
-        field.substring(0, field.indexOf("["))
-      } else {
-        field
-      }
-      t match {
-        case ArrayType(elementType) =>
-          val ordinals = ordinalRegExp
-            .findAllIn(field)
-            .matchData
-            .map(_.group(2))
-          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
-            GetItem(v1, Literal(v2.toInt))), elementType)
-        case MapType(keyType, valueType) =>
-          val ordinals = ordinalRegExp
-            .findAllIn(field)
-            .matchData
-            .map(_.group(2))
-          // TODO: we should recover the JVM type of keyType to match the
-          // actual type of the key?! should we restrict ourselves to NativeType?
-          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
-            GetItem(v1, Literal(v2, keyType))), valueType)
-        case StructType(fields) =>
-          val structField = fields
-            .find(_.name == fieldName)
-          if (!structField.isDefined) {
-            throw new TreeNodeException(
-              this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
-          }
-          structField.get.dataType match {
-            case ArrayType(elementType) =>
-              val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
-              (ordinals.foldLeft(
-                GetField(exp, fieldName).asInstanceOf[Expression])((v1: Expression, v2: String) =>
-                GetItem(v1, Literal(v2.toInt))),
-                elementType)
-            case _ =>
-              (GetField(exp, fieldName), structField.get.dataType)
-          }
-        case _ =>
-          expType
-      }
-    }
-
+    // TODO: extend SqlParser to handle field expressions
     val parts = name.split("\\.")
     // Collect all attributes that are output by this nodes children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
@@ -124,33 +78,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         remainingParts.head
       }
       if (option.name == relevantRemaining) (option, remainingParts.tail.toList) :: Nil else Nil*/
+      // If the first part of the desired name matches a qualifier for this possible match, drop it.
+      /* TODO: from rebase!
+      val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
+      if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+      */
     }
 
     options.distinct match {
-      case (a, Nil) :: Nil => {
-        a.dataType match {
-          case ArrayType(_) | MapType(_, _) =>
-            val expression = expandFunc((a: Expression, a.dataType), name)._1
-            Some(Alias(expression, name)())
-          case _ => Some(a)
-        }
-      } // One match, no nested fields, use it.
+      case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
       case (a, nestedFields) :: Nil =>
         a.dataType match {
           case StructType(fields) =>
-            // this is compatibility reasons with earlier code!
-            // TODO: why only nestedFields and not parts?
-            // check for absence of nested arrays so there are only fields
-            if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[(\\d+|\\w+)\\]+"))) {
-              Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
-            } else {
-              val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
-              Some(Alias(expression, nestedFields.last)())
-            }
-          case _ =>
-            val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
-            Some(Alias(expression, nestedFields.last)())
+            Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
+          case _ => None // Don't know how to resolve these field references
         }
       case Nil => None // No matches.
       case ambiguousReferences =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index f694e59252fe3..27c4c2ac76487 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -63,7 +63,8 @@ private[sql] object CatalystConverter {
   val MAP_VALUE_SCHEMA_NAME = "value"
   val MAP_SCHEMA_NAME = "map"
 
-  type ArrayScalaType[T] = Array[T]
+  // TODO: consider using Array[T] for arrays to avoid boxing of primitive types
+  type ArrayScalaType[T] = Seq[T]
   type StructScalaType[T] = Seq[T]
   type MapScalaType[K, V] = Map[K, V]
 
@@ -426,7 +427,7 @@ private[parquet] class CatalystArrayConverter(
   override def end(): Unit = {
     assert(parent != null)
     // here we need to make sure to use ArrayScalaType
-    parent.updateField(index, buffer.toArray)
+    parent.updateField(index, buffer.toArray.toSeq)
     clearBuffer()
   }
 }
@@ -451,8 +452,7 @@ private[parquet] class CatalystNativeArrayConverter(
 
   type NativeType = elementType.JvmType
 
-  private var buffer: CatalystConverter.ArrayScalaType[NativeType] =
-    elementType.classTag.newArray(capacity)
+  private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity)
 
   private var elements: Int = 0
 
@@ -526,15 +526,14 @@ private[parquet] class CatalystNativeArrayConverter(
     // here we need to make sure to use ArrayScalaType
     parent.updateField(
       index,
-      buffer.slice(0, elements))
+      buffer.slice(0, elements).toSeq)
     clearBuffer()
   }
 
   private def checkGrowBuffer(): Unit = {
     if (elements >= capacity) {
       val newCapacity = 2 * capacity
-      val tmp: CatalystConverter.ArrayScalaType[NativeType] =
-        elementType.classTag.newArray(newCapacity)
+      val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity)
       Array.copy(buffer, 0, tmp, 0, capacity)
       buffer = tmp
       capacity = newCapacity
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index d3449a6cbf77e..a11e19f3b6e63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -377,13 +377,13 @@ private[sql] object ParquetTestData {
     val map2 = r1.addGroup(2)
     val keyValue3 = map2.addGroup(0)
     // TODO: currently only string key type supported
-    keyValue3.add(0, "7")
+    keyValue3.add(0, "seven")
     val valueGroup1 = keyValue3.addGroup(1)
     valueGroup1.add(0, 42.toLong)
     valueGroup1.add(1, "the answer")
     val keyValue4 = map2.addGroup(0)
     // TODO: currently only string key type supported
-    keyValue4.add(0, "8")
+    keyValue4.add(0, "eight")
     val valueGroup2 = keyValue4.addGroup(1)
     valueGroup2.add(0, 49.toLong)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 39a4e25ae7ae3..3cf7b0f10d09e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -32,14 +32,16 @@ import parquet.hadoop.ParquetFileWriter
 import parquet.hadoop.util.ContextUtil
 import parquet.schema.MessageTypeParser
 
+import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.IntegerType
+import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.TestData
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
 import org.apache.spark.util.Utils
 
 // Implicits
@@ -71,7 +73,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
 
   var testRDD: SchemaRDD = null
 
+  // TODO: remove this once SqlParser can parse nested select statements
+  var nestedParserSqlContext: NestedParserSQLContext = null
+
   override def beforeAll() {
+    nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext)
     ParquetTestData.writeFile()
     ParquetTestData.writeFilterFile()
     ParquetTestData.writeNestedFile1()
@@ -221,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
       .map(i => TestRDDEntry(i, s"val_$i"))
     source_rdd.registerAsTable("source")
-    val dest_rdd = createParquetFile(dirname.toString, ("key", IntegerType), ("value", StringType))
+    val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString)
     dest_rdd.registerAsTable("dest")
     sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect()
     val rdd_copy1 = sql("SELECT * FROM dest").collect()
@@ -474,11 +480,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("Projection in addressbook") {
-    val data = TestSQLContext
+    val data = nestedParserSqlContext
       .parquetFile(ParquetTestData.testNestedDir1.toString)
       .toSchemaRDD
     data.registerAsTable("data")
-    val tmp = sql("SELECT owner, contacts[1].name FROM data").collect()
+    val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data")
+    val tmp = query.collect()
     assert(tmp.size === 2)
     assert(tmp(0).size === 2)
     assert(tmp(0)(0) === "Julien Le Dem")
@@ -488,21 +495,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("Simple query on nested int data") {
-    val data = TestSQLContext
+    val data = nestedParserSqlContext
      .parquetFile(ParquetTestData.testNestedDir2.toString)
       .toSchemaRDD
     data.registerAsTable("data")
-    val result1 = sql("SELECT entries[0].value FROM data").collect()
+    val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect()
     assert(result1.size === 1)
     assert(result1(0).size === 1)
     assert(result1(0)(0) === 2.5)
-    val result2 = sql("SELECT entries[0] FROM data").collect()
+    val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect()
     assert(result2.size === 1)
     val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]]
     assert(subresult1.size === 2)
     assert(subresult1(0) === 2.5)
     assert(subresult1(1) === false)
-    val result3 = sql("SELECT outerouter FROM data").collect()
+    val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect()
     val subresult2 = result3(0)(0)
       .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
       .asInstanceOf[CatalystConverter.ArrayScalaType[_]]
@@ -515,19 +522,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("nested structs") {
-    val data = TestSQLContext
+    val data = nestedParserSqlContext
       .parquetFile(ParquetTestData.testNestedDir3.toString)
       .toSchemaRDD
     data.registerAsTable("data")
-    val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
+    val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
     assert(result1.size === 1)
     assert(result1(0).size === 1)
     assert(result1(0)(0) === false)
-    val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
+    val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
     assert(result2.size === 1)
     assert(result2(0).size === 1)
     assert(result2(0)(0) === true)
-    val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
+    val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
     assert(result3.size === 1)
     assert(result3(0).size === 1)
     assert(result3(0)(0) === false)
@@ -546,30 +553,30 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     assert(result1(0)(0)
       .asInstanceOf[CatalystConverter.MapScalaType[String, _]]
       .getOrElse("key2", 0) === 2)
-    val result2 = sql("SELECT data1[key1] FROM mapTable").collect()
+    val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect()
     assert(result2(0)(0) === 1)
   }
 
   test("map with struct values") {
-    val data = TestSQLContext
+    val data = nestedParserSqlContext
       .parquetFile(ParquetTestData.testNestedDir4.toString)
       .toSchemaRDD
     data.registerAsTable("mapTable")
-    val result1 = sql("SELECT data2 FROM mapTable").collect()
+    val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect()
     assert(result1.size === 1)
     val entry1 = result1(0)(0)
       .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
-      .getOrElse("7", null)
+      .getOrElse("seven", null)
     assert(entry1 != null)
     assert(entry1(0) === 42)
     assert(entry1(1) === "the answer")
     val entry2 = result1(0)(0)
       .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
-      .getOrElse("8", null)
+      .getOrElse("eight", null)
     assert(entry2 != null)
     assert(entry2(0) === 49)
     assert(entry2(1) === null)
-    val result2 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM mapTable").collect()
+    val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect()
     assert(result2.size === 1)
     assert(result2(0)(0) === 42.toLong)
     assert(result2(0)(1) === "the answer")
@@ -580,15 +587,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     // has no effect in this test case
     val tmpdir = Utils.createTempDir()
     Utils.deleteRecursively(tmpdir)
-    val result = TestSQLContext
+    val result = nestedParserSqlContext
       .parquetFile(ParquetTestData.testNestedDir1.toString)
       .toSchemaRDD
     result.saveAsParquetFile(tmpdir.toString)
-    TestSQLContext
+    nestedParserSqlContext
       .parquetFile(tmpdir.toString)
       .toSchemaRDD
       .registerAsTable("tmpcopy")
-    val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
+    val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
     assert(tmpdata.size === 2)
     assert(tmpdata(0).size === 2)
     assert(tmpdata(0)(0) === "Julien Le Dem")
@@ -599,34 +606,34 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
   }
 
   test("Writing out Map and reading it back in") {
-    val data = TestSQLContext
+    val data = nestedParserSqlContext
       .parquetFile(ParquetTestData.testNestedDir4.toString)
       .toSchemaRDD
     val tmpdir = Utils.createTempDir()
     Utils.deleteRecursively(tmpdir)
     data.saveAsParquetFile(tmpdir.toString)
-    TestSQLContext
+    nestedParserSqlContext
       .parquetFile(tmpdir.toString)
       .toSchemaRDD
       .registerAsTable("tmpmapcopy")
-    val result1 = sql("SELECT data1[key2] FROM tmpmapcopy").collect()
+    val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect()
     assert(result1.size === 1)
     assert(result1(0)(0) === 2)
-    val result2 = sql("SELECT data2 FROM tmpmapcopy").collect()
+    val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect()
     assert(result2.size === 1)
     val entry1 = result2(0)(0)
       .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
-      .getOrElse("7", null)
+      .getOrElse("seven", null)
     assert(entry1 != null)
     assert(entry1(0) === 42)
     assert(entry1(1) === "the answer")
     val entry2 = result2(0)(0)
       .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
-      .getOrElse("8", null)
+      .getOrElse("eight", null)
     assert(entry2 != null)
     assert(entry2(0) === 49)
     assert(entry2(1) === null)
-    val result3 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM tmpmapcopy").collect()
+    val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect()
     assert(result3.size === 1)
     assert(result3(0)(0) === 42.toLong)
     assert(result3(0)(1) === "the answer")
@@ -774,3 +781,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     assert(mapResult2(2) === 1.3f)
   }
 }
+
+// TODO: the code below is needed temporarily until the standard parser is able to parse
+// nested field expressions correctly
+class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) {
+  override protected[sql] val parser = new NestedSqlParser()
+}
+
+class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) {
+  override def identChar = letter | elem('_')
+  delimiters += (".")
+}
+
+class NestedSqlParser extends SqlParser {
+  override val lexical = new NestedSqlLexical(reservedWords)
+
+  override protected lazy val baseExpression: PackratParser[Expression] =
+    expression ~ "[" ~ expression <~ "]" ^^ {
+      case base ~ _ ~ ordinal => GetItem(base, ordinal)
+    } |
+    expression ~ "." ~ ident ^^ {
+      case base ~ _ ~ fieldName => GetField(base, fieldName)
+    } |
+    TRUE ^^^ Literal(true, BooleanType) |
+    FALSE ^^^ Literal(false, BooleanType) |
+    cast |
+    "(" ~> expression <~ ")" |
+    function |
+    "-" ~> literal ^^ UnaryMinus |
+    ident ^^ UnresolvedAttribute |
+    "*" ^^^ Star(None) |
+    literal
+}