[SPARK-11967][SQL] Use varargs for multiple paths in DataFrameReader.
rxin committed Nov 24, 2015
1 parent c7f95df commit bd9a538
Showing 5 changed files with 62 additions and 19 deletions.
19 changes: 8 additions & 11 deletions python/pyspark/sql/readwriter.py
@@ -109,7 +109,7 @@ def options(self, **options):
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
:param path: optional string for file-system backed data sources.
:param path: optional string or a list of string for file-system backed data sources.
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
@@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options):
... opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
>>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
... 'python/test_support/sql/people1.json'])
>>> df.dtypes
@@ -129,13 +130,7 @@ def load(self, path=None, format=None, schema=None, **options):
self.schema(schema)
self.options(**options)
if path is not None:
if type(path) == list:
paths = path
gateway = self._sqlContext._sc._gateway
jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths)
return self._df(self._jreader.load(jpaths))
else:
return self._df(self._jreader.load(path))
return self._df(self._jreader.load(path))
else:
return self._df(self._jreader.load())

@@ -173,7 +168,7 @@ def json(self, path, schema=None):
"""
if schema is not None:
self.schema(schema)
if isinstance(path, basestring):
if isinstance(path, basestring) or type(path) == list:
return self._df(self._jreader.json(path))
elif isinstance(path, RDD):
return self._df(self._jreader.json(path._jrdd))
@@ -205,16 +200,18 @@ def parquet(self, *paths):

@ignore_unicode_prefix
@since(1.6)
def text(self, path):
def text(self, paths):
"""Loads a text file and returns a [[DataFrame]] with a single string column named "text".
Each line in the text file is a new row in the resulting DataFrame.
:param paths: string, or list of strings, for input path(s).
>>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
return self._df(self._jreader.text(path))
return self._df(self._jreader.text(paths))

@since(1.5)
def orc(self, path):
36 changes: 29 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,17 +24,17 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.StringUtils

import org.apache.spark.{Logging, Partition}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.SqlParser
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.types.StructType
import org.apache.spark.{Logging, Partition}
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}

/**
* :: Experimental ::
@@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.4.0
*/
// TODO: Remove this one in Spark 2.0.
def load(path: String): DataFrame = {
option("path", path).load()
}
@@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.6.0
*/
def load(paths: Array[String]): DataFrame = {
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
}
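
For illustration, a minimal usage sketch of the new varargs overload; the file names and the sqlContext value are assumptions, not part of the patch:

    // Sketch only: with load(paths: String*), callers pass any number of input paths directly.
    val df = sqlContext.read
      .format("json")
      .load("/data/events/2015-11-23.json", "/data/events/2015-11-24.json")
    // Internally the paths are folded into a single "paths" option, roughly
    //   "/data/events/2015-11-23.json,/data/events/2015-11-24.json",
    // with any literal ',' in a path escaped as "\," via StringUtils.escapeString.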

@@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
*
* @param path input path
* @since 1.4.0
*/
// TODO: Remove this one in Spark 2.0.
def json(path: String): DataFrame = format("json").load(path)

/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
* <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
* <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
* </li>
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
*
* @since 1.6.0
*/
def json(paths: String*): DataFrame = format("json").load(paths : _*)
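
A similar sketch for the new multi-path json overload; the file names are made up, and primitivesAsString is one of the JSON options documented above:

    // Reads two JSON files into one DataFrame, inferring every primitive as a string.
    val people = sqlContext.read
      .option("primitivesAsString", "true")
      .json("examples/people1.json", "examples/people2.json")
    people.printSchema()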

/**
* Loads an `JavaRDD[String]` storing JSON objects (one object per record) and
* returns the result as a [[DataFrame]].
@@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* sqlContext.read().text("/path/to/spark/README.md")
* }}}
*
* @param path input path
* @param paths input path
* @since 1.6.0
*/
def text(path: String): DataFrame = format("text").load(path)
@scala.annotation.varargs
def text(paths: String*): DataFrame = format("text").load(paths : _*)
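
And a corresponding sketch for the varargs text method, with hypothetical paths; each line of every input file becomes one row of the single-column result:

    // Loads two plain-text files; every line becomes a row.
    val lines = sqlContext.read.text("/docs/README.md", "/docs/CHANGES.txt")
    println(lines.count())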

///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
@@ -298,4 +298,27 @@ public void pivot() {
Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}

public void testGenericLoad() {
DataFrame df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());

DataFrame df2 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}

@Test
public void testTextLoad() {
DataFrame df1 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());

DataFrame df2 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}
}
1 change: 1 addition & 0 deletions sql/core/src/test/resources/text-suite2.txt
@@ -0,0 +1 @@
This is another file for testing multi path loading.
@@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val dir2 = new File(dir, "dir2").getCanonicalPath
df2.write.format("json").save(dir2)

checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)),
checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
Row(1, 22) :: Row(2, 23) :: Nil)

checkAnswer(sqlContext.read.format("json").load(dir1),
