Skip to content

Commit

Permalink
Using assertIs in unit tests where appropriate. (googleapis#3629)
Browse files Browse the repository at this point in the history
* Using assertIs in unit tests where appropriate.

Any usage of `self.assertTrue(a is b)` has become
`self.assertIs(a, b)`.

* Converting some assertFalse(a is b) to assertIsNot(a, b).
  • Loading branch information
dhermes authored and landrito committed Aug 21, 2017
1 parent 9ce0e69 commit bbb1255
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 57 deletions.
2 changes: 1 addition & 1 deletion spanner/tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _make_one(self, session):
def test_ctor(self):
session = object()
base = self._make_one(session)
self.assertTrue(base._session is session)
self.assertIs(base._session, session)


class Test_options_with_prefix(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions spanner/tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _compare_values(self, result, source):
def test_ctor(self):
session = _Session()
base = self._make_one(session)
self.assertTrue(base._session is session)
self.assertIs(base._session, session)
self.assertEqual(len(base._mutations), 0)

def test__check_state_virtual(self):
Expand Down Expand Up @@ -177,7 +177,7 @@ def _getTargetClass(self):
def test_ctor(self):
session = _Session()
batch = self._make_one(session)
self.assertTrue(batch._session is session)
self.assertIs(batch._session, session)

def test_commit_already_committed(self):
from google.cloud.spanner.keyset import KeySet
Expand Down
16 changes: 8 additions & 8 deletions spanner/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _constructor_test_helper(self, expected_scopes, creds,
expected_creds = expected_creds or creds.with_scopes.return_value
self.assertIs(client._credentials, expected_creds)

self.assertTrue(client._credentials is expected_creds)
self.assertIs(client._credentials, expected_creds)
if expected_scopes is not None:
creds.with_scopes.assert_called_once_with(expected_scopes)

Expand Down Expand Up @@ -162,7 +162,7 @@ def __init__(self, *args, **kwargs):

self.assertTrue(isinstance(api, _Client))
again = client.instance_admin_api
self.assertTrue(again is api)
self.assertIs(again, api)
self.assertEqual(api.kwargs['lib_name'], 'gccl')
self.assertIs(api.kwargs['credentials'], client.credentials)

Expand All @@ -183,7 +183,7 @@ def __init__(self, *args, **kwargs):

self.assertTrue(isinstance(api, _Client))
again = client.database_admin_api
self.assertTrue(again is api)
self.assertIs(again, api)
self.assertEqual(api.kwargs['lib_name'], 'gccl')
self.assertIs(api.kwargs['credentials'], client.credentials)

Expand All @@ -202,7 +202,7 @@ def test_copy(self):
def test_credentials_property(self):
credentials = _Credentials()
client = self._make_one(project=self.PROJECT, credentials=credentials)
self.assertTrue(client.credentials is credentials)
self.assertIs(client.credentials, credentials)

def test_project_name_property(self):
credentials = _Credentials()
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_list_instance_configs_wo_paging(self):
project, page_size, options = api._listed_instance_configs
self.assertEqual(project, self.PATH)
self.assertEqual(page_size, None)
self.assertTrue(options.page_token is INITIAL_PAGE)
self.assertIs(options.page_token, INITIAL_PAGE)
self.assertEqual(
options.kwargs['metadata'],
[('google-cloud-resource-prefix', client.project_name)])
Expand Down Expand Up @@ -292,7 +292,7 @@ def test_instance_factory_defaults(self):
self.assertIsNone(instance.configuration_name)
self.assertEqual(instance.display_name, self.INSTANCE_ID)
self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT)
self.assertTrue(instance._client is client)
self.assertIs(instance._client, client)

def test_instance_factory_explicit(self):
from google.cloud.spanner.instance import Instance
Expand All @@ -309,7 +309,7 @@ def test_instance_factory_explicit(self):
self.assertEqual(instance.configuration_name, self.CONFIGURATION_NAME)
self.assertEqual(instance.display_name, self.DISPLAY_NAME)
self.assertEqual(instance.node_count, self.NODE_COUNT)
self.assertTrue(instance._client is client)
self.assertIs(instance._client, client)

def test_list_instances_wo_paging(self):
from google.cloud._testing import _GAXPageIterator
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_list_instances_wo_paging(self):
self.assertEqual(project, self.PATH)
self.assertEqual(filter_, 'name:TEST')
self.assertEqual(page_size, None)
self.assertTrue(options.page_token is INITIAL_PAGE)
self.assertIs(options.page_token, INITIAL_PAGE)
self.assertEqual(
options.kwargs['metadata'],
[('google-cloud-resource-prefix', client.project_name)])
Expand Down
36 changes: 18 additions & 18 deletions spanner/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_ctor_defaults(self):
database = self._make_one(self.DATABASE_ID, instance)

self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertTrue(database._instance is instance)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertIsInstance(database._pool, BurstyPool)
# BurstyPool does not create sessions during 'bind()'.
Expand All @@ -61,7 +61,7 @@ def test_ctor_w_explicit_pool(self):
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertTrue(database._instance is instance)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertIs(database._pool, pool)
self.assertIs(pool._bound, database)
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_ctor_w_ddl_statements_ok(self):
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS,
pool=pool)
self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertTrue(database._instance is instance)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS)

def test_from_pb_bad_database_name(self):
Expand Down Expand Up @@ -196,10 +196,10 @@ def _mock_spanner_client(*args, **kwargs):

with _Monkey(MUT, SpannerClient=_mock_spanner_client):
api = database.spanner_api
self.assertTrue(api is _client)
self.assertIs(api, _client)
# API instance is cached
again = database.spanner_api
self.assertTrue(again is api)
self.assertIs(again, api)

def test___eq__(self):
instance = _Instance(self.INSTANCE_NAME)
Expand Down Expand Up @@ -567,8 +567,8 @@ def test_session_factory(self):
session = database.session()

self.assertTrue(isinstance(session, Session))
self.assertTrue(session.session_id is None)
self.assertTrue(session._database is database)
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)

def test_execute_sql_defaults(self):
QUERY = 'SELECT * FROM employees'
Expand Down Expand Up @@ -671,7 +671,7 @@ def test_batch(self):

checkout = database.batch()
self.assertIsInstance(checkout, BatchCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)

def test_snapshot_defaults(self):
from google.cloud.spanner.database import SnapshotCheckout
Expand All @@ -685,7 +685,7 @@ def test_snapshot_defaults(self):

checkout = database.snapshot()
self.assertIsInstance(checkout, SnapshotCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertIsNone(checkout._min_read_timestamp)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -707,7 +707,7 @@ def test_snapshot_w_read_timestamp(self):
checkout = database.snapshot(read_timestamp=now)

self.assertIsInstance(checkout, SnapshotCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertEqual(checkout._read_timestamp, now)
self.assertIsNone(checkout._min_read_timestamp)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -729,7 +729,7 @@ def test_snapshot_w_min_read_timestamp(self):
checkout = database.snapshot(min_read_timestamp=now)

self.assertIsInstance(checkout, SnapshotCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertEqual(checkout._min_read_timestamp, now)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -750,7 +750,7 @@ def test_snapshot_w_max_staleness(self):
checkout = database.snapshot(max_staleness=staleness)

self.assertIsInstance(checkout, SnapshotCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertIsNone(checkout._min_read_timestamp)
self.assertEqual(checkout._max_staleness, staleness)
Expand All @@ -771,7 +771,7 @@ def test_snapshot_w_exact_staleness(self):
checkout = database.snapshot(exact_staleness=staleness)

self.assertIsInstance(checkout, SnapshotCheckout)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertIsNone(checkout._min_read_timestamp)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -788,7 +788,7 @@ def _getTargetClass(self):
def test_ctor(self):
database = _Database(self.DATABASE_NAME)
checkout = self._make_one(database)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)

def test_context_mgr_success(self):
import datetime
Expand Down Expand Up @@ -865,7 +865,7 @@ def test_ctor_defaults(self):
pool.put(session)

checkout = self._make_one(database)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertIsNone(checkout._min_read_timestamp)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -891,7 +891,7 @@ def test_ctor_w_read_timestamp(self):
pool.put(session)

checkout = self._make_one(database, read_timestamp=now)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertEqual(checkout._read_timestamp, now)
self.assertIsNone(checkout._min_read_timestamp)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -918,7 +918,7 @@ def test_ctor_w_min_read_timestamp(self):
pool.put(session)

checkout = self._make_one(database, min_read_timestamp=now)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertEqual(checkout._min_read_timestamp, now)
self.assertIsNone(checkout._max_staleness)
Expand All @@ -944,7 +944,7 @@ def test_ctor_w_max_staleness(self):
pool.put(session)

checkout = self._make_one(database, max_staleness=staleness)
self.assertTrue(checkout._database is database)
self.assertIs(checkout._database, database)
self.assertIsNone(checkout._read_timestamp)
self.assertIsNone(checkout._min_read_timestamp)
self.assertEqual(checkout._max_staleness, staleness)
Expand Down
16 changes: 8 additions & 8 deletions spanner/tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_constructor_defaults(self):
client = object()
instance = self._make_one(self.INSTANCE_ID, client)
self.assertEqual(instance.instance_id, self.INSTANCE_ID)
self.assertTrue(instance._client is client)
self.assertTrue(instance.configuration_name is None)
self.assertIs(instance._client, client)
self.assertIs(instance.configuration_name, None)
self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT)
self.assertEqual(instance.display_name, self.INSTANCE_ID)

Expand All @@ -64,7 +64,7 @@ def test_constructor_non_default(self):
node_count=self.NODE_COUNT,
display_name=DISPLAY_NAME)
self.assertEqual(instance.instance_id, self.INSTANCE_ID)
self.assertTrue(instance._client is client)
self.assertIs(instance._client, client)
self.assertEqual(instance.configuration_name, self.CONFIG_NAME)
self.assertEqual(instance.node_count, self.NODE_COUNT)
self.assertEqual(instance.display_name, DISPLAY_NAME)
Expand All @@ -78,10 +78,10 @@ def test_copy(self):
new_instance = instance.copy()

# Make sure the client copy succeeded.
self.assertFalse(new_instance._client is client)
self.assertIsNot(new_instance._client, client)
self.assertEqual(new_instance._client, client)
# Make sure the client got copied to a new instance.
self.assertFalse(instance is new_instance)
self.assertIsNot(instance, new_instance)
self.assertEqual(instance, new_instance)

def test__update_from_pb_success(self):
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_database_factory_defaults(self):

self.assertTrue(isinstance(database, Database))
self.assertEqual(database.database_id, DATABASE_ID)
self.assertTrue(database._instance is instance)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertIsInstance(database._pool, BurstyPool)
pool = database._pool
Expand All @@ -516,7 +516,7 @@ def test_database_factory_explicit(self):

self.assertTrue(isinstance(database, Database))
self.assertEqual(database.database_id, DATABASE_ID)
self.assertTrue(database._instance is instance)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS)
self.assertIs(database._pool, pool)
self.assertIs(pool._bound, database)
Expand Down Expand Up @@ -547,7 +547,7 @@ def test_list_databases_wo_paging(self):
instance_name, page_size, options = api._listed_databases
self.assertEqual(instance_name, self.INSTANCE_NAME)
self.assertEqual(page_size, None)
self.assertTrue(options.page_token is INITIAL_PAGE)
self.assertIs(options.page_token, INITIAL_PAGE)
self.assertEqual(options.kwargs['metadata'],
[('google-cloud-resource-prefix', instance.name)])

Expand Down
14 changes: 7 additions & 7 deletions spanner/tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def _make_one(self, *args, **kwargs):
def test_constructor(self):
database = _Database(self.DATABASE_NAME)
session = self._make_one(database)
self.assertTrue(session.session_id is None)
self.assertTrue(session._database is database)
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)

def test___lt___(self):
database = _Database(self.DATABASE_NAME)
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_snapshot_created(self):
snapshot = session.snapshot()

self.assertIsInstance(snapshot, Snapshot)
self.assertTrue(snapshot._session is session)
self.assertIs(snapshot._session, session)
self.assertTrue(snapshot._strong)

def test_read_not_created(self):
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_batch_created(self):
batch = session.batch()

self.assertIsInstance(batch, Batch)
self.assertTrue(batch._session is session)
self.assertIs(batch._session, session)

def test_transaction_not_created(self):
database = _Database(self.DATABASE_NAME)
Expand All @@ -371,8 +371,8 @@ def test_transaction_created(self):
transaction = session.transaction()

self.assertIsInstance(transaction, Transaction)
self.assertTrue(transaction._session is session)
self.assertTrue(session._transaction is transaction)
self.assertIs(transaction._session, session)
self.assertIs(session._transaction, transaction)

def test_transaction_w_existing_txn(self):
database = _Database(self.DATABASE_NAME)
Expand All @@ -382,7 +382,7 @@ def test_transaction_w_existing_txn(self):
existing = session.transaction()
another = session.transaction() # invalidates existing txn

self.assertTrue(session._transaction is another)
self.assertIs(session._transaction, another)
self.assertTrue(existing._rolled_back)

def test_retry_transaction_w_commit_error_txn_already_begun(self):
Expand Down
Loading

0 comments on commit bbb1255

Please sign in to comment.