From 38f634e978e73faac5173fb9fbb22ed48c8c7f3f Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Fri, 9 Jan 2015 15:57:13 -0800
Subject: [PATCH] Remove Option from createRelation.

---
 .../apache/spark/sql/json/JSONRelation.scala  | 19 +++++++++---
 .../org/apache/spark/sql/sources/ddl.scala    | 31 +++++++++++++------
 .../apache/spark/sql/sources/interfaces.scala |  2 +-
 .../spark/sql/sources/TableScanSuite.scala    |  6 ++--
 4 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 47da6f5e5237b..a9a6696cb15e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the given parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 991e8d888c110..7f0fd73aa721c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 97157c868cc90..990f7e0e74bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index a0e0172a4a548..605190f5ae6a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -45,7 +45,7 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
@@ -53,10 +53,10 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>
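
As an illustration of the API after this change, below is a minimal sketch of a third-party data source written against the patched traits. It is not part of the patch: RangeSource, RangeScan, and the 'from'/'to' options are hypothetical names modeled on the AllDataTypesScanSource test above, and the single-column fallback schema is an arbitrary choice; only RelationProvider, SchemaRelationProvider, TableScan, and the createRelation signatures come from the code being modified.

import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.sources._

// Hypothetical source that can either supply its own schema (RelationProvider)
// or accept a user-specified one (SchemaRelationProvider), mirroring the two
// createRelation overloads that JSONRelation's DefaultSource now has.
class RangeSource extends RelationProvider with SchemaRelationProvider {

  // No user-specified schema: fall back to a default single-column schema.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    createRelation(
      sqlContext,
      parameters,
      StructType(StructField("i", IntegerType, nullable = false) :: Nil))

  // User-specified schema: after this patch it arrives as a plain StructType,
  // with no Option to unwrap.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    RangeScan(parameters("from").toInt, parameters("to").toInt, schema)(sqlContext)
}

case class RangeScan(
    from: Int,
    to: Int,
    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
  extends TableScan {

  override def schema = userSpecifiedSchema

  // One Row per integer in the inclusive range [from, to].
  override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
}

With the new dispatch in ddl.scala, a CREATE TEMPORARY TABLE ... USING statement that carries a schema clause is routed to the three-argument overload, and one without a schema to the two-argument overload; a provider that implements only one of the two traits now fails fast with the corresponding "should extend ..." error, instead of a SchemaRelationProvider receiving a None it may not be able to use. For example (hypothetical fully-qualified name):

  CREATE TEMPORARY TABLE ints
  USING org.example.RangeSource
  OPTIONS (from '1', to '10')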