[SPARK-33836][SS][PYTHON] Expose DataStreamReader.table and DataStreamWriter.toTable

### What changes were proposed in this pull request?

This PR proposes to expose `DataStreamReader.table` (SPARK-32885) and `DataStreamWriter.toTable` (SPARK-32896) to PySpark, which are the only way to read from and write to a table in Structured Streaming.
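As a quick illustration (not part of the patch itself), a minimal sketch of the two entry points, assuming an active `spark` session and an existing table; the table names and checkpoint path are placeholders:

```
# Read a streaming DataFrame from an existing table (DataStreamReader.table).
sdf = spark.readStream.table('events')

# Write the stream to another table, creating it if needed (DataStreamWriter.toTable).
query = sdf.writeStream.toTable('events_copy', format='parquet',
                                checkpointLocation='/tmp/events_copy_checkpoint')
query.processAllAvailable()
query.stop()
```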

### Why are the changes needed?

Please refer to SPARK-32885 and SPARK-32896 for the rationale behind these public APIs. This PR only exposes them to PySpark.

### Does this PR introduce _any_ user-facing change?

Yes, PySpark users will be able to read from and write to tables in Structured Streaming queries.

### How was this patch tested?

Manually tested.

> v1 table

>> create table A and ingest into table A

```
spark.sql("""
create table table_pyspark_parquet (
    value long,
    `timestamp` timestamp
) USING parquet
""")
df = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query = df.writeStream.toTable('table_pyspark_parquet', checkpointLocation='/tmp/checkpoint5')
query.lastProgress
query.stop()
```

>> read table A and ingest into table B, which doesn't exist

```
df2 = spark.readStream.table('table_pyspark_parquet')
query2 = df2.writeStream.toTable('table_pyspark_parquet_nonexist', format='parquet', checkpointLocation='/tmp/checkpoint2')
query2.lastProgress
query2.stop()
```

>> query the tables

```
spark.sql("DESCRIBE TABLE table_pyspark_parquet").show()
spark.sql("SELECT * FROM table_pyspark_parquet").show()

spark.sql("DESCRIBE TABLE table_pyspark_parquet_nonexist").show()
spark.sql("SELECT * FROM table_pyspark_parquet_nonexist").show()
```

> v2 table (leveraging Apache Iceberg, as it provides a V2 table implementation and a custom catalog)

>> create table A and ingest into table A

```
spark.sql("""
create table iceberg_catalog.default.table_pyspark_v2table (
    value long,
    `timestamp` timestamp
) USING iceberg
""")
df = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query = df.select('value', 'timestamp').writeStream.toTable('iceberg_catalog.default.table_pyspark_v2table', checkpointLocation='/tmp/checkpoint_v2table_1')
query.lastProgress
query.stop()
```

>> ingest into the non-existent table B

```
df2 = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query2 = df2.select('value', 'timestamp').writeStream.toTable('iceberg_catalog.default.table_pyspark_v2table_nonexist', checkpointLocation='/tmp/checkpoint_v2table_2')
query2.lastProgress
query2.stop()
```

>> ingest into the non-existent table C, partitioned by `value % 10`

```
df3 = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
df3a = df3.selectExpr('value', 'timestamp', 'value % 10 AS partition').repartition('partition')
query3 = df3a.writeStream.partitionBy('partition').toTable('iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned', checkpointLocation='/tmp/checkpoint_v2table_3')
query3.lastProgress
query3.stop()
```

>> query the tables

```
spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table").show()

spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table_nonexist").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table_nonexist").show()

spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned").show()
```
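
After the manual run, the test tables can optionally be dropped; a small cleanup sketch using the same table names as above:

```
# Optional cleanup of the tables created during the manual test above.
for table in [
    'table_pyspark_parquet',
    'table_pyspark_parquet_nonexist',
    'iceberg_catalog.default.table_pyspark_v2table',
    'iceberg_catalog.default.table_pyspark_v2table_nonexist',
    'iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned',
]:
    spark.sql('DROP TABLE IF EXISTS %s' % table)
```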

Closes #30835 from HeartSaVioR/SPARK-33836.

Lead-authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Co-authored-by: Jungtaek Lim (HeartSaVioR) <kabhwan.opensource@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
HeartSaVioR authored and HyukjinKwon committed Dec 21, 2020
1 parent b313a1e commit 8d4d433
Showing 3 changed files with 139 additions and 2 deletions.
105 changes: 103 additions & 2 deletions python/pyspark/sql/streaming.py
@@ -953,6 +953,36 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
        else:
            raise TypeError("path can be only a single string")

    def table(self, tableName):
        """Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should
        support streaming mode.

        .. versionadded:: 3.1.0

        Parameters
        ----------
        tableName : str
            string, for the name of the table.

        Returns
        -------
        :class:`DataFrame`

        Notes
        -----
        This API is evolving.

        Examples
        --------
        >>> csv_sdf = spark.readStream.table('input_table') # doctest: +SKIP
        >>> csv_sdf.isStreaming # doctest: +SKIP
        True
        """
        if isinstance(tableName, str):
            return self._df(self._jreader.table(tableName))
        else:
            raise TypeError("tableName can be only a single string")


class DataStreamWriter(object):
"""
@@ -987,7 +1017,7 @@ def outputMode(self, outputMode):
        * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to
          the sink
        * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the sink
-          every time these is some updates
+          every time these are some updates
        * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
          written to the sink every time there are some updates. If the query doesn't contain
          aggregations, it will be equivalent to `append` mode.
@@ -1416,7 +1446,7 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query
            * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the
              sink
            * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the
-              sink every time these is some updates
+              sink every time these are some updates
            * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
              written to the sink every time there are some updates. If the query doesn't contain
              aggregations, it will be equivalent to `append` mode.
@@ -1464,6 +1494,77 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query
        else:
            return self._sq(self._jwrite.start(path))

    def toTable(self, tableName, format=None, outputMode=None, partitionBy=None, queryName=None,
                **options):
        """
        Starts the execution of the streaming query, which will continually output results to the
        given table as new data arrives.

        A new table will be created if the table not exists. The returned
        :class:`StreamingQuery` object can be used to interact with the stream.

        .. versionadded:: 3.1.0

        Parameters
        ----------
        tableName : str
            string, for the name of the table.
        format : str, optional
            the format used to save.
        outputMode : str, optional
            specifies how data of a streaming DataFrame/Dataset is written to a
            streaming sink.

            * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the
              sink
            * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the
              sink every time these are some updates
            * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
              written to the sink every time there are some updates. If the query doesn't contain
              aggregations, it will be equivalent to `append` mode.
        partitionBy : str or list, optional
            names of partitioning columns
        queryName : str, optional
            unique name for the query
        **options : dict
            All other string options. You may want to provide a `checkpointLocation`.

        Notes
        -----
        This API is evolving.

        Examples
        --------
        >>> sq = sdf.writeStream.format('parquet').queryName('this_query').option(
        ...     'checkpointLocation', '/tmp/checkpoint').toTable('output_table') # doctest: +SKIP
        >>> sq.isActive # doctest: +SKIP
        True
        >>> sq.name # doctest: +SKIP
        'this_query'
        >>> sq.stop() # doctest: +SKIP
        >>> sq.isActive # doctest: +SKIP
        False
        >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').toTable(
        ...     'output_table', queryName='that_query', outputMode="append", format='parquet',
        ...     checkpointLocation='/tmp/checkpoint') # doctest: +SKIP
        >>> sq.name # doctest: +SKIP
        'that_query'
        >>> sq.isActive # doctest: +SKIP
        True
        >>> sq.stop() # doctest: +SKIP
        """
        # TODO(SPARK-33659): document the current behavior for DataStreamWriter.toTable API
        self.options(**options)
        if outputMode is not None:
            self.outputMode(outputMode)
        if partitionBy is not None:
            self.partitionBy(partitionBy)
        if format is not None:
            self.format(format)
        if queryName is not None:
            self.queryName(queryName)
        return self._sq(self._jwrite.toTable(tableName))


def _test():
    import doctest
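For reference, the keyword arguments accepted by `toTable` above are simply forwarded to the existing builder methods before the query starts, so a call like the ones in the manual test is roughly equivalent to this chained form (a sketch; `sdf`, the column/table names, and the checkpoint path are placeholders):

```
sq = (sdf.writeStream
      .option('checkpointLocation', '/tmp/checkpoint')  # from **options
      .outputMode('append')                             # outputMode=...
      .partitionBy('partition')                         # partitionBy=...
      .format('parquet')                                # format=...
      .queryName('this_query')                          # queryName=...
      .toTable('output_table'))
```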
10 changes: 10 additions & 0 deletions python/pyspark/sql/streaming.pyi
@@ -151,6 +151,7 @@ class DataStreamReader(OptionUtils):
        recursiveFileLookup: Optional[Union[bool, str]] = ...,
        unescapedQuoteHandling: Optional[str] = ...,
    ) -> DataFrame: ...
    def table(self, tableName: str) -> DataFrame: ...

class DataStreamWriter:
    def __init__(self, df: DataFrame) -> None: ...
@@ -185,3 +186,12 @@ class DataStreamWriter:
    def foreachBatch(
        self, func: Callable[[DataFrame, int], None]
    ) -> DataStreamWriter: ...
    def toTable(
        self,
        tableName: str,
        format: Optional[str] = ...,
        outputMode: Optional[str] = ...,
        partitionBy: Optional[Union[str, List[str]]] = ...,
        queryName: Optional[str] = ...,
        **options: OptionalPrimitiveType
    ) -> StreamingQuery: ...
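
With these stub additions, static type checkers can resolve the new methods. A minimal sketch of type-annotated usage (assuming an active `spark` session; the table names and checkpoint path are placeholders):

```
from pyspark.sql import DataFrame
from pyspark.sql.streaming import StreamingQuery

sdf: DataFrame = spark.readStream.table('input_table')
sq: StreamingQuery = sdf.writeStream.toTable(
    'output_table', format='parquet', checkpointLocation='/tmp/checkpoint')
sq.stop()
```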
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_streaming.py
@@ -19,7 +19,9 @@
import shutil
import tempfile
import time
from random import randint

from pyspark.sql import Row
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -569,6 +571,30 @@ def collectBatch(df, id):
            if q:
                q.stop()

    def test_streaming_read_from_table(self):
        input_table_name = "sample_input_table_%d" % randint(0, 100000000)
        self.spark.sql("CREATE TABLE %s (value string) USING parquet" % input_table_name)
        self.spark.sql("INSERT INTO %s VALUES ('aaa'), ('bbb'), ('ccc')" % input_table_name)
        df = self.spark.readStream.table(input_table_name)
        self.assertTrue(df.isStreaming)
        q = df.writeStream.format('memory').queryName('this_query').start()
        q.processAllAvailable()
        q.stop()
        result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
        self.assertEqual([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')], result)

    def test_streaming_write_to_table(self):
        output_table_name = "sample_output_table_%d" % randint(0, 100000000)
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
        q = df.writeStream.toTable(output_table_name, format='parquet', checkpointLocation=tmpPath)
        self.assertTrue(q.isActive)
        time.sleep(3)
        q.stop()
        result = self.spark.sql("SELECT value FROM %s" % output_table_name).collect()
        self.assertTrue(len(result) > 0)


if __name__ == "__main__":
    import unittest
