Commit

Changing to Seq for ArrayType, refactoring SQLParser for nested field extension
AndreSchumacher committed Jun 19, 2014
1 parent cbb5793 commit 191bc0d
Showing 6 changed files with 144 additions and 162 deletions.
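At a high level, this commit does two things: array-valued columns are now materialized as Seq instead of Array, and the lexer is split out of SqlParser into a standalone SqlLexical so that bracketed ordinal access (data[0], map[key]) can be parsed into GetItem expressions. A minimal sketch of one practical motivation for the Seq switch (plain Scala, not Spark API; names are illustrative):

object SeqVsArrayDemo extends App {
  // Plain Scala, not Spark API: Seq compares structurally, while a JVM
  // array compares by reference — one practical reason to hand array
  // columns back to callers as Seq.
  val asSeq: Seq[Int] = Seq(1, 2, 3)
  val asArray: Array[Int] = Array(1, 2, 3)

  println(asSeq == Seq(1, 2, 3))      // true  — element-wise equality
  println(asArray == Array(1, 2, 3))  // false — reference equality only
}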
@@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected case class Keyword(str: String)

protected implicit def asParser(k: Keyword): Parser[String] =
allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

protected class SqlLexical extends StdLexical {
case class FloatLit(chars: String) extends Token {
override def toString = chars
}
override lazy val token: Parser[Token] = (
identChar ~ rep( identChar | digit ) ^^
{ case first ~ rest => processIdent(first :: rest mkString "") }
| rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
case i ~ None => NumericLit(i mkString "")
case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
}
| '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
{ case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
| '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
{ case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
| EofCh ^^^ EOF
| '\'' ~> failure("unclosed string literal")
| '\"' ~> failure("unclosed string literal")
| delim
| failure("illegal character")
)

override def identChar = letter | elem('.') | elem('_') | elem('[') | elem(']')

override def whitespace: Parser[Any] = rep(
whitespaceChar
| '/' ~ '*' ~ comment
| '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
| '#' ~ rep( chrExcept(EofCh, '\n') )
| '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
| '/' ~ '*' ~ failure("unclosed comment")
)
}

override val lexical = new SqlLexical
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

protected val ALL = Keyword("ALL")
protected val AND = Keyword("AND")
@@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
this.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword])

/** Generate all variations of upper and lower case of a given string */
private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
if (s == "") {
Stream(prefix)
} else {
allCaseVersions(s.tail, prefix + s.head.toLower) ++
allCaseVersions(s.tail, prefix + s.head.toUpper)
}
}
.map(_.invoke(this).asInstanceOf[Keyword].str)

lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str))

lexical.delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]"
)
override val lexical = new SqlLexical(reservedWords)

protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
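reservedWords above is collected by reflecting over the parser's Keyword-typed members. A self-contained toy version of the same trick (the Keywords class and its fields are illustrative, not Spark's):

case class Keyword(str: String)

class Keywords {
  val ALL = Keyword("ALL")
  val AND = Keyword("AND")

  def reservedWords: Seq[String] =
    this.getClass.getMethods
      .filter(_.getReturnType == classOf[Keyword])  // one accessor per val
      .map(_.invoke(this).asInstanceOf[Keyword].str)
      .toList
}

object KeywordsDemo extends App {
  println(new Keywords().reservedWords.sorted)  // List(ALL, AND)
}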
@@ -383,14 +332,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)

protected lazy val baseExpression: PackratParser[Expression] =
expression ~ "[" ~ expression <~ "]" ^^ {
expression ~ "[" ~ expression <~ "]" ^^ {
case base ~ _ ~ ordinal => GetItem(base, ordinal)
} |
TRUE ^^^ Literal(true, BooleanType) |
FALSE ^^^ Literal(false, BooleanType) |
cast |
"(" ~> expression <~ ")" |
"[" ~> literal <~ "]" |
function |
"-" ~> literal ^^ UnaryMinus |
ident ^^ UnresolvedAttribute |
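The first alternative of baseExpression is what enables nested ordinal access: base "[" ordinal "]" folds into a GetItem node. A toy parser-combinator grammar with the same shape (stand-in AST classes, not Catalyst's):

import scala.util.parsing.combinator.JavaTokenParsers

object BracketDemo extends JavaTokenParsers with App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class GetItem(base: Expr, ordinal: Int) extends Expr

  // base "[" ordinal "]" — same shape as the first baseExpression alternative
  val expr: Parser[Expr] =
    ident ~ ("[" ~> wholeNumber <~ "]") ^^ {
      case base ~ ord => GetItem(Attr(base), ord.toInt)
    }

  println(parseAll(expr, "data[0]"))  // parsed: GetItem(Attr(data),0)
}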
@@ -400,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected lazy val dataType: Parser[DataType] =
STRING ^^^ StringType
}

class SqlLexical(val keywords: Seq[String]) extends StdLexical {
case class FloatLit(chars: String) extends Token {
override def toString = chars
}

reserved ++= keywords.flatMap(w => allCaseVersions(w))

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]"
)

override lazy val token: Parser[Token] = (
identChar ~ rep( identChar | digit ) ^^
{ case first ~ rest => processIdent(first :: rest mkString "") }
| rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
case i ~ None => NumericLit(i mkString "")
case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
}
| '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
{ case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
| '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
{ case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
| EofCh ^^^ EOF
| '\'' ~> failure("unclosed string literal")
| '\"' ~> failure("unclosed string literal")
| delim
| failure("illegal character")
)

override def identChar = letter | elem('_') | elem('.')

override def whitespace: Parser[Any] = rep(
whitespaceChar
| '/' ~ '*' ~ comment
| '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
| '#' ~ rep( chrExcept(EofCh, '\n') )
| '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
| '/' ~ '*' ~ failure("unclosed comment")
)

/** Generate all variations of upper and lower case of a given string */
def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
if (s == "") {
Stream(prefix)
} else {
allCaseVersions(s.tail, prefix + s.head.toLower) ++
allCaseVersions(s.tail, prefix + s.head.toUpper)
}
}
}
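Note that every keyword is registered in all of its case variants, so allCaseVersions enumerates 2^n strings for an n-letter keyword. The same helper, runnable standalone:

object CaseVersionsDemo extends App {
  def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
    if (s == "") Stream(prefix)
    else allCaseVersions(s.tail, prefix + s.head.toLower) ++
         allCaseVersions(s.tail, prefix + s.head.toUpper)

  // 2^2 = 4 variants for a two-letter keyword
  println(allCaseVersions("as").toList)  // List(as, aS, As, AS)
}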
@@ -50,7 +50,9 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
null
} else {
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = value.asInstanceOf[Array[_]]
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (o >= baseValue.size || o < 0) {
null
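With the evaluated child now a Seq, GetItem's ordinal lookup reduces to bounds-checked indexing. Schematically (plain Scala, not Catalyst's eval contract):

object GetItemSketch extends App {
  // Bounds-checked lookup, as in the eval branch above: out-of-range
  // ordinals yield null rather than throwing.
  def getItem(value: Seq[_], ordinal: Int): Any =
    if (ordinal >= value.size || ordinal < 0) null else value(ordinal)

  println(getItem(Seq("a", "b", "c"), 1))  // b
  println(getItem(Seq("a", "b", "c"), 5))  // null
}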
@@ -58,53 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
* can contain ordinal expressions, such as `field[i][j][k]...`.
*/
def resolve(name: String): Option[NamedExpression] = {
def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
val (exp, t) = expType
val ordinalRegExp = """(\[(\d+|\w+)\])""".r
val fieldName = if (ordinalRegExp.findFirstIn(field).isDefined) {
field.substring(0, field.indexOf("["))
} else {
field
}
t match {
case ArrayType(elementType) =>
val ordinals = ordinalRegExp
.findAllIn(field)
.matchData
.map(_.group(2))
(ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
GetItem(v1, Literal(v2.toInt))), elementType)
case MapType(keyType, valueType) =>
val ordinals = ordinalRegExp
.findAllIn(field)
.matchData
.map(_.group(2))
// TODO: we should recover the JVM type of keyType to match the
// actual type of the key?! should we restrict ourselves to NativeType?
(ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
GetItem(v1, Literal(v2, keyType))), valueType)
case StructType(fields) =>
val structField = fields
.find(_.name == fieldName)
if (!structField.isDefined) {
throw new TreeNodeException(
this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
}
structField.get.dataType match {
case ArrayType(elementType) =>
val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
(ordinals.foldLeft(
GetField(exp, fieldName).asInstanceOf[Expression])((v1: Expression, v2: String) =>
GetItem(v1, Literal(v2.toInt))),
elementType)
case _ =>
(GetField(exp, fieldName), structField.get.dataType)
}
case _ =>
expType
}
}

// TODO: extend SqlParser to handle field expressions
val parts = name.split("\\.")
// Collect all attributes that are output by this node's children where either the first part
// matches the name or where the first part matches the scope and the second part matches the
@@ -124,33 +78,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
remainingParts.head
}
if (option.name == relevantRemaining) (option, remainingParts.tail.toList) :: Nil else Nil*/
// If the first part of the desired name matches a qualifier for this possible match, drop it.
/* TODO: from rebase!
val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
*/
}

options.distinct match {
case (a, Nil) :: Nil => {
a.dataType match {
case ArrayType(_) | MapType(_, _) =>
val expression = expandFunc((a: Expression, a.dataType), name)._1
Some(Alias(expression, name)())
case _ => Some(a)
}
} // One match, no nested fields, use it.
case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
// One match, but we also need to extract the requested nested field.
case (a, nestedFields) :: Nil =>
a.dataType match {
case StructType(fields) =>
// this is compatibility reasons with earlier code!
// TODO: why only nestedFields and not parts?
// check for absence of nested arrays so there are only fields
if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[(\\d+|\\w+)\\]+"))) {
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
} else {
val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
Some(Alias(expression, nestedFields.last)())
}
case _ =>
val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
Some(Alias(expression, nestedFields.last)())
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case _ => None // Don't know how to resolve these field references
}
case Nil => None // No matches.
case ambiguousReferences =>
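The surviving resolution path folds the remaining name parts into nested GetField accesses. A toy model of that fold (Expr, Attr and GetField are stand-ins, not Catalyst's classes):

sealed trait Expr
case class Attr(name: String) extends Expr
case class GetField(child: Expr, field: String) extends Expr

object ResolveDemo extends App {
  // Resolving "a.b.c": attribute "a" matched, nested fields List("b", "c")
  val resolved = List("b", "c").foldLeft(Attr("a"): Expr)(GetField)
  println(resolved)  // GetField(GetField(Attr(a),b),c)
}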
@@ -63,7 +63,9 @@ private[sql] object CatalystConverter {
val MAP_VALUE_SCHEMA_NAME = "value"
val MAP_SCHEMA_NAME = "map"

type ArrayScalaType[T] = Array[T]
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
type ArrayScalaType[T] = Seq[T]
type StructScalaType[T] = Seq[T]
type MapScalaType[K, V] = Map[K, V]
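The TODO above is about boxing: Array[Int] compiles to a primitive int[] on the JVM, while a generic Seq[Int] stores boxed java.lang.Integer elements. A small illustration:

object BoxingDemo extends App {
  val primitive: Array[Int] = Array(1, 2, 3)  // raw int[] on the JVM
  println(primitive.getClass.getSimpleName)   // int[]

  val boxed: Seq[Int] = List(1, 2, 3)         // elements stored boxed
  println(boxed.asInstanceOf[Seq[AnyRef]].head.getClass.getName)  // java.lang.Integer
}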

@@ -426,7 +427,7 @@ private[parquet] class CatalystArrayConverter(
override def end(): Unit = {
assert(parent != null)
// here we need to make sure to use ArrayScalaType
parent.updateField(index, buffer.toArray)
parent.updateField(index, buffer.toArray.toSeq)
clearBuffer()
}
}
@@ -451,8 +452,7 @@ private[parquet] class CatalystNativeArrayConverter(

type NativeType = elementType.JvmType

private var buffer: CatalystConverter.ArrayScalaType[NativeType] =
elementType.classTag.newArray(capacity)
private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity)

private var elements: Int = 0

@@ -526,15 +526,14 @@ private[parquet] class CatalystNativeArrayConverter(
// here we need to make sure to use ArrayScalaType
parent.updateField(
index,
buffer.slice(0, elements))
buffer.slice(0, elements).toSeq)
clearBuffer()
}

private def checkGrowBuffer(): Unit = {
if (elements >= capacity) {
val newCapacity = 2 * capacity
val tmp: CatalystConverter.ArrayScalaType[NativeType] =
elementType.classTag.newArray(newCapacity)
val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity)
Array.copy(buffer, 0, tmp, 0, capacity)
buffer = tmp
capacity = newCapacity
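checkGrowBuffer uses the standard doubling strategy: allocate twice the capacity, copy, swap. Extracted as a standalone sketch (generic element type via ClassTag, mirroring elementType.classTag above):

import scala.reflect.ClassTag

class GrowableBuffer[T: ClassTag](initialCapacity: Int = 16) {
  private var capacity = initialCapacity
  private var buffer = new Array[T](capacity)
  private var elements = 0

  def +=(value: T): Unit = {
    if (elements >= capacity) {        // same doubling as checkGrowBuffer
      val newCapacity = 2 * capacity
      val tmp = new Array[T](newCapacity)
      Array.copy(buffer, 0, tmp, 0, capacity)
      buffer = tmp
      capacity = newCapacity
    }
    buffer(elements) = value
    elements += 1
  }

  // only the filled prefix is handed out, as in end() above
  def result: Seq[T] = buffer.slice(0, elements).toSeq
}

object BufferDemo extends App {
  val buf = new GrowableBuffer[Int](2)
  (1 to 5).foreach(buf += _)
  println(buf.result)  // Seq of 1..5 (a WrappedArray)
}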
@@ -377,13 +377,13 @@ private[sql] object ParquetTestData {
val map2 = r1.addGroup(2)
val keyValue3 = map2.addGroup(0)
// TODO: currently only string key type supported
keyValue3.add(0, "7")
keyValue3.add(0, "seven")
val valueGroup1 = keyValue3.addGroup(1)
valueGroup1.add(0, 42.toLong)
valueGroup1.add(1, "the answer")
val keyValue4 = map2.addGroup(0)
// TODO: currently only string key type supported
keyValue4.add(0, "8")
keyValue4.add(0, "eight")
val valueGroup2 = keyValue4.addGroup(1)
valueGroup2.add(0, 49.toLong)

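For orientation, a sketch of the Scala value the map group above encodes (Payload is an illustrative stand-in for the two-field value group; the second field is not set for "eight" in the lines shown):

object TestDataSketch extends App {
  case class Payload(num: Long, str: Option[String])

  val map2 = Map(
    "seven" -> Payload(42L, Some("the answer")),
    "eight" -> Payload(49L, None)  // second field not set in the lines shown
  )
  println(map2("seven"))  // Payload(42,Some(the answer))
}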