diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6d3bef007e6e4..d0cc4f79957a9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -20,10 +20,8 @@ package org.apache.spark.util import java.io._ import java.net._ import java.nio.ByteBuffer -import java.util.{Properties, Locale, Random, UUID} -import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} - -import org.apache.log4j.PropertyConfigurator +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.{Locale, Properties, Random, UUID} import scala.collection.JavaConversions._ import scala.collection.Map @@ -37,9 +35,10 @@ import com.google.common.io.{ByteStreams, Files} import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ -import tachyon.client.{TachyonFile,TachyonFS} +import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -952,8 +951,8 @@ private[spark] object Utils extends Logging { */ def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { val trace = Thread.currentThread.getStackTrace() - .filterNot { ste:StackTraceElement => - // When running under some profilers, the current stack trace might contain some bogus + .filterNot { ste:StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus // frames. This is intended to ensure that we don't crash in these situations by // ignoring any frames that we can't examine. (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) @@ -1354,7 +1353,7 @@ private[spark] object Utils extends Logging { } } - /** + /** * Execute the given block, logging and re-throwing any uncaught exception. * This is particularly useful for wrapping code that runs in a thread, to ensure * that exceptions are printed, and to avoid having to catch Throwable. @@ -1551,7 +1550,6 @@ private[spark] object Utils extends Logging { "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") PropertyConfigurator.configure(pro) } - } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a75af94d29303..4889fea24af6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -75,6 +75,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => setConf(key, value) + case _ => + } + /** * :: DeveloperApi :: * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 031b695169cea..3429fbad024c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -48,43 +48,35 @@ case class SetCommand( extends LeafNode with Command with Logging { override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { - // Set value for key k. - case (Some(k), Some(v)) => - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + // Configures the deprecated "mapred.reduce.tasks" property. + case (Some(SQLConf.Deprecated.MAPRED_REDUCE_TASKS), Some(v)) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") - } else { - context.setConf(k, v) - Array(s"$k=$v") - } - - // Query the value bound to key k. - case (Some(k), _) => - // TODO (lian) This is just a workaround to make the Simba ODBC driver work. - // Should remove this once we get the ODBC driver updated. - if (k == "-v") { - val hiveJars = Seq( - "hive-exec-0.12.0.jar", - "hive-service-0.12.0.jar", - "hive-common-0.12.0.jar", - "hive-hwi-0.12.0.jar", - "hive-0.12.0.jar").mkString(":") - - Array( - "system:java.class.path=" + hiveJars, - "system:sun.java.command=shark.SharkServer2") - } - else { - Array(s"$k=${context.getConf(k, "")}") - } - - // Query all key-value pairs that are set in the SQLConf of the context. - case (None, None) => - context.getAllConfs.map { case (k, v) => - s"$k=$v" - }.toSeq + context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) + Seq(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + + // Configures a single property. + case (Some(k), Some(v)) => + context.setConf(k, v) + Seq(s"$k=$v") + + // Queries all key-value pairs that are set in the SQLConf of the context. Notice that different + // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties + // while "SET -v" returns all properties.) + case (Some("-v") | None, None) => + context.getAllConfs.map { case (k, v) => s"$k=$v" }.toSeq + + // Queries the deprecated "mapred.reduce.tasks" property. + case (Some(SQLConf.Deprecated.MAPRED_REDUCE_TASKS), None) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}") + + // Queries a single property. + case (Some(k), None) => + Seq(s"$k=${context.getConf(k, "")}") case _ => throw new IllegalArgumentException() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f2389f8f0591e..4fdfc2ba1ba6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -17,9 +17,18 @@ package org.apache.spark.sql.test +import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends SQLContext( + new SparkContext( + "local[2]", + "TestSQLContext", + new SparkConf().set("spark.sql.testkey", "true"))) { + + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 584f71b3c13d5..60701f0e154f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,16 +17,25 @@ package org.apache.spark.sql +import org.scalatest.FunSuiteLike + import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest { +class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" val testVal = "test.val.0" + test("propagate from spark conf") { + // We create a new context here to avoid order dependence with other tests that might call + // clear(). + val newContext = new SQLContext(TestSQLContext.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") == "true") + } + test("programmatic ways of basic setting and getting") { clear() assert(getAllConfs.size === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 42923b6a288d9..c6b790a4b6a23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -63,8 +63,7 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean, - binaryField: Array[Byte]) + booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -75,13 +74,14 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, - binaryField: Array[Byte], array: Seq[Int], arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], data: Data) +case class BinaryData(binaryData: Array[Byte]) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -117,26 +117,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray)) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - } + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } - test("Treat binary as string") { + test("read/write binary data") { + // Since equality for Array[Byte] is broken we test this separately. + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) + parquetFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + } + + ignore("Treat binary as string") { val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString // Create the test file. @@ -151,37 +151,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA StructField("c2", BinaryType, false) :: Nil) val schemaRDD1 = applySchema(rowRDD, schema) schemaRDD1.saveAsParquetFile(path) - val resultWithBinary = parquetFile(path).collect - range.foreach { - i => - assert(resultWithBinary(i).getInt(0) === i) - assert(resultWithBinary(i)(1) === s"val_$i".getBytes) - } - - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - // This ParquetRelation always use Parquet types to derive output. - val parquetRelation = new ParquetRelation( - path.toString, - Some(TestSQLContext.sparkContext.hadoopConfiguration), - TestSQLContext) { - override val output = - ParquetTypesConverter.convertToAttributes( - ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, - TestSQLContext.isParquetBinaryAsString) - } - val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) - val resultWithString = schemaRDD.collect - range.foreach { - i => - assert(resultWithString(i).getInt(0) === i) - assert(resultWithString(i)(1) === s"val_$i") - } + checkAnswer( + parquetFile(path).select('c1, 'c2.cast(StringType)), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - schemaRDD.registerTempTable("tmp") + setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + parquetFile(path).printSchema() checkAnswer( - sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + parquetFile(path), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) + // Set it back. TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) @@ -284,34 +263,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) + val data = sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray, (0 until x), (0 until x).map(Option(_).filter(_ % 3 == 0)), (0 until x).map(i => i -> i.toLong).toMap, (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), Data((0 until x), Nested(x, s"$x")))) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) - assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) - assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) - assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) - } + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } test("self-join parquet files") { @@ -408,23 +372,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } - test("Saving case class RDD table to file and reading it back in") { - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") - } - Utils.deleteRecursively(file) - } - test("Read a parquet file instead of a directory") { val file = getTempFilePath("parquet") val path = file.toString @@ -457,32 +404,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) - assert(rdd_copy1(0).apply(0) === 1) - assert(rdd_copy1(0).apply(1) === "val_1") - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! + sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) assert(rdd_copy2.size === 200) - assert(rdd_copy2(0).apply(0) === 1) - assert(rdd_copy2(0).apply(1) === "val_1") - assert(rdd_copy2(99).apply(0) === 100) - assert(rdd_copy2(99).apply(1) === "val_100") - assert(rdd_copy2(100).apply(0) === 1) - assert(rdd_copy2(100).apply(1) === "val_1") Utils.deleteRecursively(dirname) } test("Insert (appending) to same table via Scala API") { - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) assert(double_rdd.size === 30) - for(i <- (0 to 14)) { - assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") - } + // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) ParquetTestData.writeFile() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index cadf7aaf42157..161f8c6199b08 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} @@ -38,24 +35,12 @@ private[hive] object HiveThriftServer2 extends Logging { def main(args: Array[String]) { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") - if (!optionsProcessor.process(args)) { System.exit(-1) } - val ss = new SessionState(new HiveConf(classOf[SessionState])) - - // Set all properties specified via command line. - val hiveConf: HiveConf = ss.getConf - hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logDebug(s"HiveConf var: $k=$v") - } - - SessionState.start(ss) - logInfo("Starting SparkContext") SparkSQLEnv.init() - SessionState.start(ss) Runtime.getRuntime.addShutdownHook( new Thread() { @@ -67,7 +52,7 @@ private[hive] object HiveThriftServer2 extends Logging { try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) - server.init(hiveConf) + server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") } catch { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 42cbf363b274f..94ec9978af85f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,18 +17,18 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConversions._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory -import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.cli._ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext @@ -57,6 +57,15 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) initCompositeService(hiveConf) } + + override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = { + getInfoType match { + case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version) + case _ => super.getInfo(sessionHandle, getInfoType) + } + } } private[thriftserver] trait ReflectedCompositeService { this: AbstractService => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 582264eb59f83..e07402c56c5b9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import org.apache.hadoop.hive.ql.session.SessionState +import scala.collection.JavaConversions._ -import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.Logging +import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -33,14 +32,18 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (hiveContext == null) { - sparkContext = new SparkContext(new SparkConf() - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + val sparkConf = new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .set("spark.sql.hive.version", "0.12.0") + sparkContext = new SparkContext(sparkConf) sparkContext.addSparkListener(new StatsReportListener()) + hiveContext = new HiveContext(sparkContext) - hiveContext = new HiveContext(sparkContext) { - @transient override lazy val sessionState = SessionState.get() - @transient override lazy val hiveconf = sessionState.getConf + if (log.isDebugEnabled) { + hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + logDebug(s"HiveConf var: $k=$v") + } } } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3475c2c9db080..e8ffbc5b954d4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,15 +18,13 @@ package org.apache.spark.sql.hive.thriftserver +import java.io._ + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} -import java.io._ -import java.util.concurrent.atomic.AtomicInteger - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -53,17 +51,19 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin.split("\\s+").toSeq ++ extraArgs } - // AtomicInteger is needed because stderr and stdout of the forked process are handled in - // different threads. - val next = new AtomicInteger(0) + var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) val buffer = new ArrayBuffer[String]() + val lock = new Object - def captureOutput(source: String)(line: String) { + def captureOutput(source: String)(line: String): Unit = lock.synchronized { buffer += s"$source> $line" - if (line.contains(expectedAnswers(next.get()))) { - if (next.incrementAndGet() == expectedAnswers.size) { + // If we haven't found all expected answers and another expected answer comes up... + if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + next += 1 + // If all expected answers have been found... + if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } } @@ -73,11 +73,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val process = (Process(command) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL CLI process exit value: $exitValue") - } - try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => @@ -88,14 +83,15 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |======================= |Spark SQL CLI command line: ${command.mkString(" ")} | - |Executed query ${next.get()} "${queries(next.get())}", - |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout. + |Executed query $next "${queries(next)}", + |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. | |${buffer.mkString("\n")} |=========================== |End CliSuite failure output |=========================== """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() @@ -107,7 +103,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute)( + runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" @@ -118,7 +114,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { -> "Time taken: ", "SELECT COUNT(*) FROM hive_test;" -> "5", - "DROP TABLE hive_test" + "DROP TABLE hive_test;" -> "Time taken: " ) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 38977ff162097..08b4cc1c42c31 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -17,32 +17,39 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} -import scala.sys.process.{Process, ProcessLogger} - import java.io.File import java.net.ServerSocket import java.sql.{DriverManager, Statement} import java.util.concurrent.TimeoutException +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.sql.catalyst.util.getTempFilePath /** * Tests for the HiveThriftServer2 using JDBC. + * + * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be + * rebuilt after changing HiveThriftServer2 related code. */ class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - private val listeningHost = "localhost" - private val listeningPort = { + def randomListeningPort = { // Let the system to choose a random available port to avoid collision with other parallel // builds. val socket = new ServerSocket(0) @@ -51,61 +58,91 @@ class HiveThriftServer2Suite extends FunSuite with Logging { port } - private val warehousePath = getTempFilePath("warehouse") - private val metastorePath = getTempFilePath("metastore") - private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val port = randomListeningPort + + startThriftServer(port, serverStartTimeout) { + val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/" + val user = System.getProperty("user.name") + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() - def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) { - val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + try { + f(statement) + } finally { + statement.close() + connection.close() + } + } + } + + def withCLIServiceClient( + serverStartTimeout: FiniteDuration = 1.minute)( + f: ThriftCLIServiceClient => Unit) { + val port = randomListeningPort + + startThriftServer(port) { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", port) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + + try { + f(client) + } finally { + transport.close() + } + } + } + def startThriftServer( + port: Int, + serverStartTimeout: FiniteDuration = 1.minute)( + f: => Unit) { + val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + + val warehousePath = getTempFilePath("warehouse") + val metastorePath = getTempFilePath("metastore") + val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" val command = - s"""$serverScript + s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"} + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port """.stripMargin.split("\\s+").toSeq - val serverStarted = Promise[Unit]() + val serverRunning = Promise[Unit]() val buffer = new ArrayBuffer[String]() + val lock = new Object - def captureOutput(source: String)(line: String) { + def captureOutput(source: String)(line: String): Unit = lock.synchronized { buffer += s"$source> $line" if (line.contains("ThriftBinaryCLIService listening on")) { - serverStarted.success(()) + serverRunning.success(()) } } - val process = Process(command).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL Thrift server process exit value: $exitValue") - } + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + val env = Seq("SPARK_TESTING" -> "0") - val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" - val user = System.getProperty("user.name") + val process = Process(command, None, env: _*).run( + ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) try { - Await.result(serverStarted.future, timeout) - - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } + Await.result(serverRunning.future, serverStartTimeout) + f } catch { case cause: Exception => cause match { case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $timeout", cause) + logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) case _ => } logError( @@ -114,14 +151,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |HiveThriftServer2Suite failure output |===================================== |HiveThriftServer2 command line: ${command.mkString(" ")} - |JDBC URI: $jdbcUri - |User: $user + |Binding port: $port + |System user: ${System.getProperty("user.name")} | |${buffer.mkString("\n")} |========================================= |End HiveThriftServer2Suite failure output |========================================= """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() @@ -130,14 +168,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("Test JDBC query execution") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - val queries = Seq( - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test", - "CACHE TABLE test") + val queries = + s"""SET spark.sql.shuffle.partitions=3; + |CREATE TABLE test(key INT, val STRING); + |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test; + |CACHE TABLE test; + """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty) queries.foreach(statement.execute) @@ -150,7 +190,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("SPARK-3004 regression: result set containing NULL") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource( "data/files/small_kv_with_null.txt") @@ -173,4 +213,31 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(!resultSet.next()) } } + + test("GetInfo Thrift API") { + withCLIServiceClient() { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(SparkContext.SPARK_VERSION, "Spark version shouldn't be \"Unknown\"") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + } + } + } + + test("Checks Hive version") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=0.12.0") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d9b2bc7348ad2..b44a94c6aed9a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -222,17 +222,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } /** - * SQLConf and HiveConf contracts: when the hive session is first initialized, params in - * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside sql() will be set in the SQLConf *as well as* - * in the HiveConf. + * SQLConf and HiveConf contracts: + * + * 1. reuse existing started SessionState if any + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. */ - @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) - @transient protected[hive] lazy val sessionState = { - val ss = new SessionState(hiveconf) - setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. - ss - } + @transient protected[hive] lazy val (hiveconf, sessionState) = + Option(SessionState.get()) + .orElse { + val newState = new SessionState(new HiveConf(classOf[SessionState])) + // Only starts newly created `SessionState` instance. Any existing `SessionState` instance + // returned by `SessionState.get()` must be the most recently started one. + SessionState.start(newState) + Some(newState) + } + .map { state => + setConf(state.getConf.getAllProperties) + if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") + if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") + (state.getConf, state) + } + .get sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") @@ -290,6 +302,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { SessionState.start(sessionState) + // Makes sure the session represented by the `sessionState` field is activated. This implies + // Spark SQL Hive support uses a single `SessionState` for all Hive operations and breaks + // session isolation under multi-user scenarios (i.e. HiveThriftServer2). + // TODO Fix session isolation + if (SessionState.get() != sessionState) { + SessionState.start(sessionState) + } + proc match { case driver: Driver => driver.init() @@ -306,7 +326,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { driver.destroy() results case _ => - sessionState.out.println(tokens(0) + " " + cmd_1) + if (sessionState.out != null) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } Seq(proc.run(cmd_1).getResponseCode.toString) } } catch { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 8bb2216b7b4f4..094e58e9863c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ object TestHive - extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) /** * A locally running test instance of Spark's Hive execution engine. @@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index d2587431951b8..cdf984420782b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.util.Try -import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -313,10 +312,10 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15") test("case sensitivity: registered table") { - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: - TestData(2, "str2") :: Nil) + TestData(2, "str2") :: Nil).toSchemaRDD testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { @@ -327,7 +326,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -467,10 +466,10 @@ class HiveQuerySuite extends HiveComparisonTest { } // Describe a registered temporary table. - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: - TestData(1, "str2") :: Nil) + TestData(1, "str2") :: Nil).toSchemaRDD testData.registerTempTable("test_describe_commands2") assertResult( @@ -520,10 +519,15 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - + val KV = "([^=]+)=([^=]*)".r + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value + case Row(KV(key, value)) => key -> value + }.toSet clear() - // "set" itself returns all config variables currently specified in SQLConf. + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) @@ -532,46 +536,21 @@ class HiveQuerySuite extends HiveComparisonTest { } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) - } + assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { sql(s"SET").collect().map(_.getString(0)) } - - // "set key" - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey").collect().map(_.getString(0)) - } - - assertResult(Array(s"$nonexistentKey=")) { - sql(s"SET $nonexistentKey").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + collectResults(sql("SET -v")) } - // Assert that sql() should have the same effects as sql() by repeating the above using sql(). - clear() - assert(sql("SET").collect().size == 0) - - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) - } - - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(s"$testKey=$testVal")) { - sql("SET").collect().map(_.getString(0)) - } - - sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { - sql("SET").collect().map(_.getString(0)) - } - - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey").collect().map(_.getString(0)) + // "SET key" + assertResult(Set(testKey -> testVal)) { + collectResults(sql(s"SET $testKey")) } assertResult(Array(s"$nonexistentKey=")) {