diff --git a/api_core/google/api_core/bidi.py b/api_core/google/api_core/bidi.py index f73c7c9dfabc..b171a4112a31 100644 --- a/api_core/google/api_core/bidi.py +++ b/api_core/google/api_core/bidi.py @@ -561,6 +561,10 @@ def _recv(self): def recv(self): return self._recoverable(self._recv) + def close(self): + self._finalize(None) + super(ResumableBidiRpc, self).close() + @property def is_active(self): """bool: True if this stream is currently open and active.""" @@ -698,7 +702,11 @@ def stop(self): if self._thread is not None: # Resume the thread to wake it up in case it is sleeping. self.resume() - self._thread.join() + # The daemonized thread may itself block, so don't wait + # for it longer than a second. + self._thread.join(1.0) + if self._thread.is_alive(): # pragma: NO COVER + _LOGGER.warning("Background thread did not exit.") self._thread = None diff --git a/api_core/tests/unit/test_bidi.py b/api_core/tests/unit/test_bidi.py index 4d185d3158e4..52215cbde22f 100644 --- a/api_core/tests/unit/test_bidi.py +++ b/api_core/tests/unit/test_bidi.py @@ -597,6 +597,31 @@ def test_recv_failure(self): assert bidi_rpc.is_active is False assert call.cancelled is True + def test_close(self): + call = mock.create_autospec(_CallAndFuture, instance=True) + + def cancel_side_effect(): + call.is_active.return_value = False + + call.cancel.side_effect = cancel_side_effect + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + bidi_rpc.open() + + bidi_rpc.close() + + should_recover.assert_not_called() + call.cancel.assert_called_once() + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + # ensure the request queue was signaled to stop. + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + assert bidi_rpc._finalized + def test_reopen_failure_on_rpc_restart(self): error1 = ValueError("1") error2 = ValueError("2")