From 265e20711510aafc956552e9684ab7a39074bf70 Mon Sep 17 00:00:00 2001 From: Vikash Singh <3116482+vi3k6i5@users.noreply.github.com> Date: Wed, 20 Apr 2022 12:33:33 +0530 Subject: [PATCH] fix: add NOT_FOUND error check in __exit__ method of SessionCheckout. (#718) * fix: Inside SnapshotCheckout __exit__ block check if NotFound exception was raised for the session and create new session if needed * test: add test for SnapshotCheckout __exit__ checks * refactor: lint fixes * test: add test case for NotFound Error in SessionCheckout context but unrelated to Sessions --- google/cloud/spanner_v1/database.py | 6 +++ tests/unit/test_database.py | 61 ++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 5dc41e525e..90916bc710 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -868,6 +868,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" + if isinstance(exc_val, NotFound): + # If NotFound exception occurs inside the with block + # then we validate if the session still exists. + if not self._session.exists(): + self._session = self._database._pool._new_session() + self._session.create() self._database._pool.put(self._session) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 9cabc99945..bd47a2ac31 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -17,7 +17,6 @@ import mock from google.api_core import gapic_v1 - from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry @@ -1792,6 +1791,66 @@ class Testing(Exception): self.assertIs(pool._session, session) + def test_context_mgr_session_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=False) + pool = database._pool = _Pool() + new_session = _Session(database, name="session-2") + new_session.create = mock.MagicMock(return_value=[]) + pool._new_session = mock.MagicMock(return_value=new_session) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Session not found") + # Assert that session-1 was removed from pool and new session was added. + self.assertEqual(pool._session, new_session) + + def test_context_mgr_table_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=True) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Table not found") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + + def test_context_mgr_unknown_error(self): + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + pool.put(session) + checkout = self._make_one(database) + + class Testing(Exception): + pass + + self.assertEqual(pool._session, session) + with self.assertRaises(Testing): + with checkout as _: + raise Testing("Unknown error.") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + class TestBatchSnapshot(_BaseTest): TABLE = "table_name"