From 8d4d43319191ada0e07e3b27abe41929aa3eefe5 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Mon, 21 Dec 2020 19:42:59 +0900
Subject: [PATCH] [SPARK-33836][SS][PYTHON] Expose DataStreamReader.table and
 DataStreamWriter.toTable

### What changes were proposed in this pull request?

This PR proposes to expose `DataStreamReader.table` (SPARK-32885) and `DataStreamWriter.toTable` (SPARK-32896) to PySpark; they are the only way to read from and write to a table in Structured Streaming.

### Why are the changes needed?

Please refer to SPARK-32885 and SPARK-32896 for the rationale behind these public APIs. This PR only exposes them to PySpark.

### Does this PR introduce _any_ user-facing change?

Yes, PySpark users will be able to read from and write to tables in Structured Streaming queries.

### How was this patch tested?

Manually tested.

> v1 table

>> create table A and ingest into table A

```
spark.sql("""
create table table_pyspark_parquet (
    value long,
    `timestamp` timestamp
) USING parquet
""")
df = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query = df.writeStream.toTable('table_pyspark_parquet', checkpointLocation='/tmp/checkpoint5')
query.lastProgress
query.stop()
```

>> read table A and ingest into table B, which doesn't exist

```
df2 = spark.readStream.table('table_pyspark_parquet')
query2 = df2.writeStream.toTable('table_pyspark_parquet_nonexist', format='parquet', checkpointLocation='/tmp/checkpoint2')
query2.lastProgress
query2.stop()
```

>> select tables

```
spark.sql("DESCRIBE TABLE table_pyspark_parquet").show()
spark.sql("SELECT * FROM table_pyspark_parquet").show()

spark.sql("DESCRIBE TABLE table_pyspark_parquet_nonexist").show()
spark.sql("SELECT * FROM table_pyspark_parquet_nonexist").show()
```

> v2 table (leveraging Apache Iceberg, as it provides a V2 table and a custom catalog as well)

>> create table A and ingest into table A

```
spark.sql("""
create table iceberg_catalog.default.table_pyspark_v2table (
    value long,
    `timestamp` timestamp
) USING iceberg
""")
df = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query = df.select('value', 'timestamp').writeStream.toTable('iceberg_catalog.default.table_pyspark_v2table', checkpointLocation='/tmp/checkpoint_v2table_1')
query.lastProgress
query.stop()
```

>> ingest into the non-existent table B

```
df2 = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
query2 = df2.select('value', 'timestamp').writeStream.toTable('iceberg_catalog.default.table_pyspark_v2table_nonexist', checkpointLocation='/tmp/checkpoint_v2table_2')
query2.lastProgress
query2.stop()
```

>> ingest into the non-existent table C, partitioned by `value % 10`

```
df3 = spark.readStream.format('rate').option('rowsPerSecond', 100).load()
df3a = df3.selectExpr('value', 'timestamp', 'value % 10 AS partition').repartition('partition')
query3 = df3a.writeStream.partitionBy('partition').toTable('iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned', checkpointLocation='/tmp/checkpoint_v2table_3')
query3.lastProgress
query3.stop()
```

>> select tables

```
spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table").show()

spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table_nonexist").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table_nonexist").show()

spark.sql("DESCRIBE TABLE iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned").show()
spark.sql("SELECT * FROM iceberg_catalog.default.table_pyspark_v2table_nonexist_partitioned").show()
```
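For quick reference, the two new user-facing calls combine as in the minimal sketch below. This is an illustration only, not part of the patch: the `events` table name and the checkpoint path are placeholders, and `spark` is assumed to be an active SparkSession whose catalog supports table creation.

```
# Minimal sketch of the new APIs (placeholder table/checkpoint names).
# Start a streaming query that writes into a table; per the new docstring,
# the table is created if it does not already exist.
stream_df = spark.readStream.format('rate').option('rowsPerSecond', 10).load()
query = stream_df.writeStream.toTable(
    'events',                                     # target table name (placeholder)
    format='parquet',                             # optional sink format
    checkpointLocation='/tmp/checkpoint_events')  # passed through **options

# Read the same table back as a streaming DataFrame.
events_df = spark.readStream.table('events')
assert events_df.isStreaming

query.stop()
```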
Closes #30835 from HeartSaVioR/SPARK-33836.

Lead-authored-by: Jungtaek Lim
Co-authored-by: Jungtaek Lim (HeartSaVioR)
Signed-off-by: HyukjinKwon
---
 python/pyspark/sql/streaming.py            | 105 ++++++++++++++++++++-
 python/pyspark/sql/streaming.pyi           |  10 ++
 python/pyspark/sql/tests/test_streaming.py |  26 +++++
 3 files changed, 139 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 365b5f38694a7..2c9c1f06274ce 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -953,6 +953,36 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         else:
             raise TypeError("path can be only a single string")

+    def table(self, tableName):
+        """Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should
+        support streaming mode.
+
+        .. versionadded:: 3.1.0
+
+        Parameters
+        ----------
+        tableName : str
+            string, for the name of the table.
+
+        Returns
+        -------
+        :class:`DataFrame`
+
+        Notes
+        -----
+        This API is evolving.
+
+        Examples
+        --------
+        >>> table_sdf = spark.readStream.table('input_table') # doctest: +SKIP
+        >>> table_sdf.isStreaming # doctest: +SKIP
+        True
+        """
+        if isinstance(tableName, str):
+            return self._df(self._jreader.table(tableName))
+        else:
+            raise TypeError("tableName can be only a single string")
+

 class DataStreamWriter(object):
     """
@@ -987,7 +1017,7 @@ def outputMode(self, outputMode):
         * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to
           the sink
         * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the sink
-          every time these is some updates
+          every time there are some updates
         * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
           written to the sink every time there are some updates. If the query doesn't contain
           aggregations, it will be equivalent to `append` mode.
@@ -1416,7 +1446,7 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query
         * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the
           sink
         * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the
-          sink every time these is some updates
+          sink every time there are some updates
         * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
           written to the sink every time there are some updates. If the query doesn't contain
           aggregations, it will be equivalent to `append` mode.
@@ -1464,6 +1494,77 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query
         else:
             return self._sq(self._jwrite.start(path))

+    def toTable(self, tableName, format=None, outputMode=None, partitionBy=None, queryName=None,
+                **options):
+        """
+        Starts the execution of the streaming query, which will continually output results to the
+        given table as new data arrives.
+
+        A new table will be created if the table does not exist. The returned
+        :class:`StreamingQuery` object can be used to interact with the stream.
+
+        .. versionadded:: 3.1.0
+
+        Parameters
+        ----------
+        tableName : str
+            string, for the name of the table.
+        format : str, optional
+            the format used to save.
+        outputMode : str, optional
+            specifies how data of a streaming DataFrame/Dataset is written to a
+            streaming sink.
+
+            * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to the
+              sink
+            * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the
+              sink every time there are some updates
+            * `update`: only the rows that were updated in the streaming DataFrame/Dataset will be
+              written to the sink every time there are some updates. If the query doesn't contain
+              aggregations, it will be equivalent to `append` mode.
+        partitionBy : str or list, optional
+            names of partitioning columns
+        queryName : str, optional
+            unique name for the query
+        **options : dict
+            All other string options. You may want to provide a `checkpointLocation`.
+
+        Notes
+        -----
+        This API is evolving.
+
+        Examples
+        --------
+        >>> sq = sdf.writeStream.format('parquet').queryName('this_query').option(
+        ...     'checkpointLocation', '/tmp/checkpoint').toTable('output_table') # doctest: +SKIP
+        >>> sq.isActive # doctest: +SKIP
+        True
+        >>> sq.name # doctest: +SKIP
+        'this_query'
+        >>> sq.stop() # doctest: +SKIP
+        >>> sq.isActive # doctest: +SKIP
+        False
+        >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').toTable(
+        ...     'output_table', queryName='that_query', outputMode="append", format='parquet',
+        ...     checkpointLocation='/tmp/checkpoint') # doctest: +SKIP
+        >>> sq.name # doctest: +SKIP
+        'that_query'
+        >>> sq.isActive # doctest: +SKIP
+        True
+        >>> sq.stop() # doctest: +SKIP
+        """
+        # TODO(SPARK-33659): document the current behavior for DataStreamWriter.toTable API
+        self.options(**options)
+        if outputMode is not None:
+            self.outputMode(outputMode)
+        if partitionBy is not None:
+            self.partitionBy(partitionBy)
+        if format is not None:
+            self.format(format)
+        if queryName is not None:
+            self.queryName(queryName)
+        return self._sq(self._jwrite.toTable(tableName))
+

 def _test():
     import doctest
diff --git a/python/pyspark/sql/streaming.pyi b/python/pyspark/sql/streaming.pyi
index 829610ad3b94b..1d05483c012f1 100644
--- a/python/pyspark/sql/streaming.pyi
+++ b/python/pyspark/sql/streaming.pyi
@@ -151,6 +151,7 @@ class DataStreamReader(OptionUtils):
         recursiveFileLookup: Optional[Union[bool, str]] = ...,
         unescapedQuoteHandling: Optional[str] = ...,
     ) -> DataFrame: ...
+    def table(self, tableName: str) -> DataFrame: ...

 class DataStreamWriter:
     def __init__(self, df: DataFrame) -> None: ...
@@ -185,3 +186,12 @@ class DataStreamWriter:
     def foreachBatch(
         self, func: Callable[[DataFrame, int], None]
     ) -> DataStreamWriter: ...
+    def toTable(
+        self,
+        tableName: str,
+        format: Optional[str] = ...,
+        outputMode: Optional[str] = ...,
+        partitionBy: Optional[Union[str, List[str]]] = ...,
+        queryName: Optional[str] = ...,
+        **options: OptionalPrimitiveType
+    ) -> StreamingQuery: ...
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index 28a50f9575a0a..44bfb2a7447ca 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -19,7 +19,9 @@
 import shutil
 import tempfile
 import time
+from random import randint

+from pyspark.sql import Row
 from pyspark.sql.functions import lit
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -569,6 +571,30 @@ def collectBatch(df, id):
             if q:
                 q.stop()

+    def test_streaming_read_from_table(self):
+        input_table_name = "sample_input_table_%d" % randint(0, 100000000)
+        self.spark.sql("CREATE TABLE %s (value string) USING parquet" % input_table_name)
+        self.spark.sql("INSERT INTO %s VALUES ('aaa'), ('bbb'), ('ccc')" % input_table_name)
+        df = self.spark.readStream.table(input_table_name)
+        self.assertTrue(df.isStreaming)
+        q = df.writeStream.format('memory').queryName('this_query').start()
+        q.processAllAvailable()
+        q.stop()
+        result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
+        self.assertEqual([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')], result)
+
+    def test_streaming_write_to_table(self):
+        output_table_name = "sample_output_table_%d" % randint(0, 100000000)
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        q = df.writeStream.toTable(output_table_name, format='parquet', checkpointLocation=tmpPath)
+        self.assertTrue(q.isActive)
+        time.sleep(3)
+        q.stop()
+        result = self.spark.sql("SELECT value FROM %s" % output_table_name).collect()
+        self.assertTrue(len(result) > 0)
+

 if __name__ == "__main__":
     import unittest