From 65e9c73958f7e28f55ae40c0c1136a5c28e2a66b Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Fri, 9 Jan 2015 15:52:00 -0800
Subject: [PATCH 1/2] Revert all changes since applying a given schema has not been tested.

---
 .../apache/spark/sql/parquet/newParquet.scala | 33 +++++++++----------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 506be8ccde6b3..2e0c6c51c00e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -22,37 +22,37 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+
 import parquet.hadoop.ParquetInputFormat
 import parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.{Partition => SparkPartition, Logging}
 import org.apache.spark.rdd.{NewHadoopPartition, RDD}
-import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+
+import org.apache.spark.sql.{SQLConf, Row, SQLContext}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.{SQLConf, SQLContext}
 
 import scala.collection.JavaConversions._
 
-
 /**
  * Allows creation of parquet based tables using the syntax
  * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
  * required is `path`, which should be the location of a collection of, optionally partitioned,
  * parquet files.
  */
-class DefaultSource extends SchemaRelationProvider {
+class DefaultSource extends RelationProvider {
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      parameters: Map[String, String]): BaseRelation = {
     val path =
       parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
 
-    ParquetRelation2(path, schema)(sqlContext)
+    ParquetRelation2(path)(sqlContext)
   }
 }
 
@@ -82,9 +82,7 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
  * discovery.
  */
 @DeveloperApi
-case class ParquetRelation2(
-    path: String,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
   extends CatalystScan with Logging {
 
   def sparkContext = sqlContext.sparkContext
@@ -135,13 +133,12 @@ case class ParquetRelation2(
 
   override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum
 
-  val dataSchema = userSpecifiedSchema.getOrElse(
-    StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
-      ParquetTypesConverter.readSchemaFromFile(
-        partitions.head.files.head.getPath,
-        Some(sparkContext.hadoopConfiguration),
-        sqlContext.isParquetBinaryAsString))
-    )
+  val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
+    ParquetTypesConverter.readSchemaFromFile(
+      partitions.head.files.head.getPath,
+      Some(sparkContext.hadoopConfiguration),
+      sqlContext.isParquetBinaryAsString))
+
 
   val dataIncludesKey =
     partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)

From 38f634e978e73faac5173fb9fbb22ed48c8c7f3f Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Fri, 9 Jan 2015 15:57:13 -0800
Subject: [PATCH 2/2] Remove Option from createRelation.

---
 .../apache/spark/sql/json/JSONRelation.scala  | 19 +++++++++---
 .../org/apache/spark/sql/sources/ddl.scala    | 31 +++++++++++++------
 .../apache/spark/sql/sources/interfaces.scala |  2 +-
 .../spark/sql/sources/TableScanSuite.scala    |  6 ++--
 4 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 47da6f5e5237b..a9a6696cb15e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 991e8d888c110..7f0fd73aa721c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 97157c868cc90..990f7e0e74bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index a0e0172a4a548..605190f5ae6a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -45,7 +45,7 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
@@ -53,10 +53,10 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>