[SPARK-43032][FOLLOWUP][SS][CONNECT] StreamingQueryManager bug fix
### What changes were proposed in this pull request?

Calling `spark.streams.get(q.id)` on a stopped query `q` should return `None` in the Python client and `null` in the Scala client. Currently it throws a `NullPointerException` instead. This PR fixes the issue.
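
The intended contract can be sketched without a Spark cluster. The snippet below is a minimal illustration only: `ToyQuery`, `ToyQueryManager`, and `describe` are hypothetical stand-ins, not part of PySpark, and they model just the one behavior described above, namely that looking up a stopped query yields `None` rather than raising.

```python
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ToyQuery:
    """Hypothetical stand-in for a streaming query handle."""
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    is_active: bool = True


class ToyQueryManager:
    """Hypothetical stand-in for `spark.streams`; it tracks active queries only."""

    def __init__(self) -> None:
        self._active: Dict[str, ToyQuery] = {}

    def start(self) -> ToyQuery:
        q = ToyQuery()
        self._active[q.id] = q
        return q

    def stop(self, q: ToyQuery) -> None:
        q.is_active = False
        self._active.pop(q.id, None)

    def get(self, query_id: str) -> Optional[ToyQuery]:
        # Stopped or unknown ids produce None instead of an exception.
        return self._active.get(query_id)


def describe(manager: ToyQueryManager, query_id: str) -> str:
    """Caller-side pattern: check for None before touching the handle."""
    q = manager.get(query_id)
    if q is None:
        return "not active"
    return "active: " + q.id
```

Callers that dereference the result directly (for example `get(q.id).id`) should only do so while the query is known to be active, as the new Python test does.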

### Why are the changes needed?

Bug fix

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit tests

Closes apache#42437 from WweiL/streaming-query-manager-get-bug-fix.

Authored-by: Wei Liu <wei.liu@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
WweiL authored and hvanhovell committed Aug 13, 2023
1 parent 65684d6 commit 06f09eb
Showing 3 changed files with 18 additions and 2 deletions.
@@ -268,6 +268,8 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {

     q.stop()
     assert(!q1.isActive)
+
+    assert(spark.streams.get(q.id) == null)
   }

   test("streaming query listener") {
@@ -3062,8 +3062,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
             .asJava)

       case StreamingQueryManagerCommand.CommandCase.GET_QUERY =>
-        val query = session.streams.get(command.getGetQuery)
-        respBuilder.setQuery(buildStreamingQueryInstance(query))
+        Option(session.streams.get(command.getGetQuery)).foreach { q =>
+          respBuilder.setQuery(buildStreamingQueryInstance(q))
+        }

       case StreamingQueryManagerCommand.CommandCase.AWAIT_ANY_TERMINATION =>
         if (command.getAwaitAnyTermination.hasTimeoutMs) {
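
The guard in the `GET_QUERY` case can be read as "populate the optional response field only when the lookup found something". A rough Python analogue of that idea follows. All names here (`QueryInstance`, `ManagerResponse`, `handle_get_query`, `client_get`) are hypothetical and chosen for illustration; this is not the Spark Connect protocol or the actual client code.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class QueryInstance:
    """Hypothetical payload describing one query."""
    id: str
    name: Optional[str] = None


@dataclass
class ManagerResponse:
    """Hypothetical response; `query` left as None models an unset field."""
    query: Optional[QueryInstance] = None


def handle_get_query(active: Dict[str, QueryInstance], query_id: str) -> ManagerResponse:
    resp = ManagerResponse()
    found = active.get(query_id)
    # Analogue of Option(...).foreach { q => ... }: skip the setter when
    # nothing was found, so no null value is ever passed to the builder.
    if found is not None:
        resp.query = found
    return resp


def client_get(resp: ManagerResponse) -> Optional[QueryInstance]:
    """A client can map the unset field to None (null in Scala)."""
    return resp.query
```

With this shape, a lookup for a stopped query produces an empty response rather than an error on the server side.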
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming.py
@@ -315,6 +315,19 @@ def _assert_exception_tree_contains_msg_default(self, exception, msg):
             contains = msg in e.desc
         self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)

+    def test_query_manager_get(self):
+        df = self.spark.readStream.format("rate").load()
+        for q in self.spark.streams.active:
+            q.stop()
+        q = df.writeStream.format("noop").start()
+
+        self.assertTrue(q.isActive)
+        self.assertTrue(q.id == self.spark.streams.get(q.id).id)
+
+        q.stop()
+
+        self.assertIsNone(self.spark.streams.get(q.id))
+
     def test_query_manager_await_termination(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
         for q in self.spark.streams.active: