diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index d05c9652753e0..3299e86b85941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} import org.apache.spark.sql.catalyst.types.StringType /** @@ -26,23 +26,25 @@ import org.apache.spark.sql.catalyst.types.StringType */ abstract class Command extends LeafNode { self: Product => - def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this + def output: Seq[Attribute] = Seq.empty } /** * Returned for commands supported by a given parser, but not catalyst. In general these are DDL * commands that are passed directly to another system. */ -case class NativeCommand(cmd: String) extends Command +case class NativeCommand(cmd: String) extends Command { + override def output = + Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) +} /** * Commands of the form "SET (key) (= value)". */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - AttributeReference("key", StringType, nullable = false)(), - AttributeReference("value", StringType, nullable = false)() - ) + BoundReference(0, AttributeReference("key", StringType, nullable = false)()), + BoundReference(1, AttributeReference("value", StringType, nullable = false)())) } /** @@ -50,11 +52,11 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman * actually performing the execution. */ case class ExplainCommand(plan: LogicalPlan) extends Command { - override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) + override def output = + Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) } /** * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. */ case class CacheCommand(tableName: String, doCache: Boolean) extends Command - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0cada785b6630..1f67c80e54906 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -161,7 +161,7 @@ class FilterPushdownSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } - + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 264192ed1aa26..585bfcbd75ecb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryColumnarTableScan @@ -147,14 +147,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { - val result = new SchemaRDD(this, parseSql(sqlText)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = @@ -280,22 +273,6 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { - case SetCommand(key, value) => - // Only this case needs to be executed eagerly. The other cases will - // be taken care of when the actual results are being extracted. - // In the case of HiveContext, sqlConf is overridden to also pass the - // pair into its HiveConf. - if (key.isDefined && value.isDefined) { - set(key.get, value.get) - } - // It doesn't matter what we return here, since this is only used - // to force the evaluation to happen eagerly. To query the results, - // one must use SchemaRDD operations to extract them. - emptyResult - case _ => executedPlan.execute() - } - lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -303,12 +280,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = { - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => executedPlan.execute() - } - } + lazy val toRdd: RDD[Row] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -330,7 +302,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * TODO: We only support primitive types, add support for nested types. */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - val schema = rdd.first.map { case (fieldName, obj) => + val schema = rdd.first().map { case (fieldName, obj) => val dataType = obj.getClass match { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 3a895e15a4508..e4cbb037709ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -50,6 +50,14 @@ private[sql] trait SchemaRDDLike { @DeveloperApi lazy val queryExecution = sqlContext.executePlan(logicalPlan) + logicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + queryExecution.toRdd + case _ => + } + override def toString = s"""${super.toString} |== Query Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index be26d19e66862..67830d422a52f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -22,45 +22,69 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +trait PhysicalCommand { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffect` so + * that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] +} + /** * :: DeveloperApi :: */ @DeveloperApi -case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) - (@transient context: SQLContext) extends LeafNode { - def execute(): RDD[Row] = (key, value) match { - // Set value for key k; the action itself would - // have been performed in QueryExecution eagerly. - case (Some(k), Some(v)) => context.emptyResult +case class SetCommandPhysical( + key: Option[String], value: Option[String], output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with PhysicalCommand { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + // Set value for key k. + case (Some(k), Some(v)) => + context.set(k, v) + Array.empty[(String, String)] + // Query the value bound to key k. - case (Some(k), None) => - val resultString = context.getOption(k) match { - case Some(v) => s"$k=$v" - case None => s"$k is undefined" - } - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + case (Some(k), _) => + Array(k -> context.getOption(k).getOrElse("")) + // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - val pairs = context.getAll - val rows = pairs.map { case (k, v) => - new GenericRow(Array[Any](s"$k=$v")) - }.toSeq - // Assume config parameters can fit into one split (machine) ;) - context.sparkContext.parallelize(rows, 1) - // The only other case is invalid semantics and is impossible. - case _ => context.emptyResult + context.getAll + + case _ => + throw new IllegalArgumentException() } + + def execute(): RDD[Row] = { + val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil } /** * :: DeveloperApi :: */ @DeveloperApi -case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) - (@transient context: SQLContext) extends UnaryNode { +case class ExplainCommandPhysical( + child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends UnaryNode with PhysicalCommand { + + // Actually "EXPLAIN" command doesn't cause any side effect. + override protected[sql] lazy val sideEffectResult: Seq[String] = child.toString.split("\n") + def execute(): RDD[Row] = { - val planString = new GenericRow(Array[Any](child.toString)) - context.sparkContext.parallelize(Seq(planString)) + val explanation = sideEffectResult.mkString("\n") + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) } override def otherCopyArgs = context :: Nil @@ -71,18 +95,19 @@ case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) */ @DeveloperApi case class CacheCommandPhysical(tableName: String, doCache: Boolean)(@transient context: SQLContext) - extends LeafNode { + extends LeafNode with PhysicalCommand { - lazy val commandSideEffect = { + override protected[sql] lazy val sideEffectResult = { if (doCache) { context.cacheTable(tableName) } else { context.uncacheTable(tableName) } + Seq.empty[Any] } override def execute(): RDD[Row] = { - commandSideEffect + sideEffectResult context.emptyResult } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c1fc99f077431..e9360b0fc7910 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -141,7 +141,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2))) } - + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { (3, "C"), (4, "D"))) } - + test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), @@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest { (2, "ABC"), (3, null))) } - + test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), @@ -382,25 +382,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Seq(testKey, testVal), + Seq(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey is undefined")) + Seq(Seq(nonexistentKey, "")) ) clear() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 64978215542ec..b3c87513a59c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -34,7 +34,6 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution._ @@ -71,14 +70,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. */ - def hiveql(hqlQuery: String): SchemaRDD = { - val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but does not perform any execution. - result.queryExecution.toRdd - result - } + def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) /** An alias for `hiveql`. */ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) @@ -164,7 +156,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Runs the specified SQL query using Hive. */ - protected def runSqlHive(sql: String): Seq[String] = { + protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 val results = runHive(sql, 100000) // It is very confusing when you only get back some of the results... @@ -228,6 +220,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override val strategies: Seq[Strategy] = Seq( CommandStrategy(self), + HiveCommandStrategy(self), TakeOrdered, ParquetOperations, HiveTableScans, @@ -251,25 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = { - def processCmd(cmd: String): RDD[Row] = { - val output = runSqlHive(cmd) - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - } - - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => analyzed match { - case NativeCommand(cmd) => processCmd(cmd) - case _ => executedPlan.execute().map(_.copy()) - } - } - } + override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -297,7 +272,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { @@ -313,10 +288,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. */ - def stringResult(): Seq[String] = analyzed match { - case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => executePlan(plan).toString.split("\n") - case query => + def stringResult(): Seq[String] = executedPlan match { + case command: PhysicalCommand => + command.sideEffectResult.map(_.toString) + + case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) @@ -327,8 +303,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { - case _: NativeCommand => "" - case _: SetCommand => "" + case _: NativeCommand => "" + case _: SetCommand => "" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8b51957162e04..e0ea7826fa70d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -75,4 +75,12 @@ private[hive] trait HiveStrategies { Nil } } + + case class HiveCommandStrategy(context: HiveContext) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.NativeCommand(sql) => + NativeCommandPhysical(sql, plan.output)(context) :: Nil + case _ => Nil + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index 29b4b9b006e45..e5fbd1c6f15f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -32,14 +32,15 @@ import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.{TaskContext, SparkException} import org.apache.spark.util.MutablePair +import org.apache.spark.{TaskContext, SparkException} /* Implicits */ import scala.collection.JavaConversions._ @@ -57,7 +58,7 @@ case class HiveTableScan( attributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Option[Expression])( - @transient val sc: HiveContext) + @transient val context: HiveContext) extends LeafNode with HiveInspectors { @@ -75,7 +76,7 @@ case class HiveTableScan( } @transient - val hadoopReader = new HadoopTableReader(relation.tableDesc, sc) + val hadoopReader = new HadoopTableReader(relation.tableDesc, context) /** * The hive object inspector for this table, which can be used to extract values from the @@ -156,7 +157,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) } - addColumnMetadataToConf(sc.hiveconf) + addColumnMetadataToConf(context.hiveconf) @transient def inputRdd = if (!relation.hiveQlTable.isPartitioned) { @@ -428,3 +429,26 @@ case class InsertIntoHiveTable( sc.sparkContext.makeRDD(Nil, 1) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class NativeCommandPhysical( + sql: String, output: Seq[Attribute])( + @transient context: HiveContext) + extends LeafNode with PhysicalCommand { + + override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) + + override def execute(): RDD[spark.sql.Row] = { + if (sideEffectResult.size == 0) { + context.emptyResult + } else { + val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) + context.sparkContext.parallelize(rows, 1) + } + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6c239b02ed09a..1c0b86e9a90a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.catalyst.plans.logical.ExplainCommand import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -166,7 +167,7 @@ class HiveQuerySuite extends HiveComparisonTest { hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") val rdd = hql("explain select key, count(value) from src group by key") assert(rdd.collect().size == 1) - assert(rdd.toString.contains("ExplainCommand")) + assert(rdd.toString.contains(ExplainCommand.getClass.getSimpleName)) assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0, "actual contents of the result should be the plans of the query to be explained") TestHive.reset() @@ -195,9 +196,11 @@ class HiveQuerySuite extends HiveComparisonTest { test("SET commands semantics for a HiveContext") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" - var testVal = "test.val.0" + val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) => + key -> value + } clear() @@ -206,41 +209,51 @@ class HiveQuerySuite extends HiveComparisonTest { // "set key=val" hql(s"SET $testKey=$testVal") - assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql("SET").collect()) + } hql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(hql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(hql("SET").collect()) + } // "set key" - assert(fromRows(hql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey").collect()) + } + + assertResult(Array(testKey -> "")) { + rowsToPairs(hql(s"SET $nonexistentKey").collect()) + } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("set").collect().size == 0) sql(s"SET $testKey=$testVal") - assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql("SET").collect()) + } sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(sql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(sql("SET").collect()) + } - assert(fromRows(sql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey").collect()) + } + + assertResult(Array(testKey -> "")) { + rowsToPairs(sql(s"SET $nonexistentKey").collect()) + } + + clear() } // Put tests that depend on specific Hive settings before these last two test,