From 1301d0dfddaae70af3d4bfffa5e33be79e73deea Mon Sep 17 00:00:00 2001 From: Danny Allen Date: Mon, 14 Dec 2020 13:13:52 -0800 Subject: [PATCH] [dvs] Clean-up dvs_database and dvs_common (#1541) - Fix formatting for default assert messages - Add support for customized assert messages - Use dataclass for PollingConfig Signed-off-by: Danny Allen --- tests/dvslib/dvs_common.py | 38 ++++--- tests/dvslib/dvs_database.py | 210 +++++++++++++++++++++-------------- tests/dvslib/dvs_vlan.py | 4 +- tests/test_fgnhg.py | 13 +-- tests/test_nat.py | 3 +- tests/test_route.py | 11 +- tests/test_sub_port_intf.py | 7 +- 7 files changed, 165 insertions(+), 121 deletions(-) diff --git a/tests/dvslib/dvs_common.py b/tests/dvslib/dvs_common.py index ba8ef8c26b50..7edae4e79297 100644 --- a/tests/dvslib/dvs_common.py +++ b/tests/dvslib/dvs_common.py @@ -1,26 +1,34 @@ """Common infrastructure for writing VS tests.""" -import collections import time +from dataclasses import dataclass from typing import Any, Callable, Tuple -_PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict') - -class PollingConfig(_PollingConfig): - """PollingConfig provides parameters that are used to control polling behavior. +@dataclass +class PollingConfig: + """Class containing parameters that are used to control polling behavior. Attributes: - polling_interval (int): How often to poll, in seconds. - timeout (int): The maximum amount of time to wait, in seconds. - strict (bool): If the strict flag is set, reaching the timeout will cause tests to fail. + polling_interval: How often to poll, in seconds. + timeout: The maximum amount of time to wait, in seconds. + strict: If the strict flag is set, reaching the timeout will cause tests to fail. """ + polling_interval: float = 0.01 + timeout: float = 5.00 + strict: bool = True + + def iterations(self) -> int: + """Return the number of iterations needed to poll with the given interval and timeout.""" + return 1 if self.polling_interval == 0 else int(self.timeout // self.polling_interval) + 1 + def wait_for_result( polling_function: Callable[[], Tuple[bool, Any]], - polling_config: PollingConfig, + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Tuple[bool, Any]: """Run `polling_function` periodically using the specified `polling_config`. @@ -29,6 +37,8 @@ def wait_for_result( must return a status which indicates if the function was succesful or not, as well as some return value. polling_config: The parameters to use to poll the polling function. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: If the polling function succeeds, then this method will return True and the output of the @@ -37,12 +47,7 @@ def wait_for_result( If it does not succeed within the provided timeout, it will return False and whatever the output of the polling function was on the final attempt. """ - if polling_config.polling_interval == 0: - iterations = 1 - else: - iterations = int(polling_config.timeout // polling_config.polling_interval) + 1 - - for _ in range(iterations): + for _ in range(polling_config.iterations()): status, result = polling_function() if status: @@ -51,6 +56,7 @@ def wait_for_result( time.sleep(polling_config.polling_interval) if polling_config.strict: - assert False, f"Operation timed out after {polling_config.timeout} seconds" + message = failure_message or f"Operation timed out after {polling_config.timeout} seconds" + assert False, message return (False, result) diff --git a/tests/dvslib/dvs_database.py b/tests/dvslib/dvs_database.py index e5d9f99fbbbe..7d268e5a6d69 100644 --- a/tests/dvslib/dvs_database.py +++ b/tests/dvslib/dvs_database.py @@ -2,7 +2,6 @@ FIXME: - Reference DBs by name rather than ID/socket - - Move DEFAULT_POLLING_CONFIG to Common - Add support for ProducerStateTable """ from typing import Dict, List @@ -11,13 +10,7 @@ class DVSDatabase: - """DVSDatabase provides access to redis databases on the virtual switch. - - By default, database operations are configured to use `DEFAULT_POLLING_CONFIG`. Users can - specify their own PollingConfig, but this shouldn't typically be necessary. - """ - - DEFAULT_POLLING_CONFIG = PollingConfig(polling_interval=0.01, timeout=5, strict=True) + """DVSDatabase provides access to redis databases on the virtual switch.""" def __init__(self, db_id: int, connector: str): """Initialize a DVSDatabase instance. @@ -99,7 +92,8 @@ def wait_for_entry( self, table_name: str, key: str, - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: """Wait for the entry stored at `key` in the specified table to exist and retrieve it. @@ -107,21 +101,19 @@ def wait_for_entry( table_name: The name of the table where the entry is stored. key: The key that maps to the entry being retrieved. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + + def access_function(): fv_pairs = self.get_entry(table_name, key) return (bool(fv_pairs), fv_pairs) - status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) - - if not status: - assert not polling_config.strict, \ - f"Entry not found: key=\"{key}\", table=\"{table_name}\"" + message = failure_message or f'Entry not found: key="{key}", table="{table_name}"' + _, result = wait_for_result(access_function, polling_config, message) return result @@ -130,7 +122,8 @@ def wait_for_fields( table_name: str, key: str, expected_fields: List[str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: """Wait for the entry stored at `key` to have the specified fields and retrieve it. @@ -142,22 +135,27 @@ def wait_for_fields( key: The key that maps to the entry being checked. expected_fields: The fields that we expect to see in the entry. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + + def access_function(): fv_pairs = self.get_entry(table_name, key) return (all(field in fv_pairs for field in expected_fields), fv_pairs) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Expected fields not found: expected={expected_fields}, \ - received={result}, key=\"{key}\", table=\"{table_name}\"" + message = failure_message or ( + f"Expected fields not found: expected={expected_fields}, received={result}, " + f'key="{key}", table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -166,34 +164,43 @@ def wait_for_field_match( table_name: str, key: str, expected_fields: Dict[str, str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: - """Wait for the entry stored at `key` to have the specified field/value pairs and retrieve it. + """Wait for the entry stored at `key` to have the specified field/values and retrieve it. - This method is useful if you only care about the contents of a subset of the fields stored in the - specified entry. + This method is useful if you only care about the contents of a subset of the fields stored + in the specified entry. Args: table_name: The name of the table where the entry is stored. key: The key that maps to the entry being checked. expected_fields: The fields and their values we expect to see in the entry. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + + def access_function(): fv_pairs = self.get_entry(table_name, key) - return (all(fv_pairs.get(k) == v for k, v in expected_fields.items()), fv_pairs) + return ( + all(fv_pairs.get(k) == v for k, v in expected_fields.items()), + fv_pairs, + ) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Expected field/value pairs not found: expected={expected_fields}, \ - received={result}, key=\"{key}\", table=\"{table_name}\"" + message = failure_message or ( + f"Expected field/value pairs not found: expected={expected_fields}, " + f'received={result}, key="{key}", table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -202,33 +209,43 @@ def wait_for_field_negative_match( table_name: str, key: str, old_fields: Dict[str, str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: - """Wait for the entry stored at `key` to have different field/value pairs than the ones specified. + """Wait for the entry stored at `key` to have different field/values than `old_fields`. - This method is useful if you expect some field to change, but you don't know their exact values. + This method is useful if you expect some field to change, but you don't know their exact + values. Args: table_name: The name of the table where the entry is stored. key: The key that maps to the entry being checked. old_fields: The original field/value pairs we expect to change. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + + def access_function(): fv_pairs = self.get_entry(table_name, key) - return (all(k in fv_pairs and fv_pairs[k] != v for k, v in old_fields.items()), fv_pairs) + return ( + all(k in fv_pairs and fv_pairs[k] != v for k, v in old_fields.items()), + fv_pairs, + ) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Did not expect field/values to match, but they did: provided={old_fields}, \ - received={result}, key=\"{key}\", table=\"{table_name}\"" + message = failure_message or ( + f"Did not expect field/values to match, but they did: provided={old_fields}, " + f'received={result}, key="{key}", table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -237,7 +254,8 @@ def wait_for_exact_match( table_name: str, key: str, expected_entry: Dict[str, str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: """Wait for the entry stored at `key` to match `expected_entry` and retrieve it. @@ -248,23 +266,27 @@ def wait_for_exact_match( key: The key that maps to the entry being checked. expected_entry: The entry we expect to see. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + def access_function(): fv_pairs = self.get_entry(table_name, key) return (fv_pairs == expected_entry, fv_pairs) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Exact match not found: expected={expected_entry}, received={result}, \ - key=\"{key}\", table=\"{table_name}\"" + message = failure_message or ( + f"Exact match not found: expected={expected_entry}, received={result}, " + f'key="{key}", table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -272,7 +294,8 @@ def wait_for_deleted_entry( self, table_name: str, key: str, - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> Dict[str, str]: """Wait for no entry to exist at `key` in the specified table. @@ -280,21 +303,26 @@ def wait_for_deleted_entry( table_name: The name of the table being checked. key: The key to be checked. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The entry stored at `key`. If no entry is found, then an empty Dict is returned. """ - def __access_function(): + + def access_function(): fv_pairs = self.get_entry(table_name, key) return (not bool(fv_pairs), fv_pairs) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Entry still exists: entry={result}, key=\"{key}\", table=\"{table_name}\"" + message = failure_message or ( + f'Entry still exists: entry={result}, key="{key}", table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -302,7 +330,8 @@ def wait_for_n_keys( self, table_name: str, num_keys: int, - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> List[str]: """Wait for the specified number of keys to exist in the table. @@ -310,22 +339,27 @@ def wait_for_n_keys( table_name: The name of the table from which to fetch the keys. num_keys: The expected number of keys to retrieve from the table. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The keys stored in the table. If no keys are found, then an empty List is returned. """ - def __access_function(): + + def access_function(): keys = self.get_keys(table_name) return (len(keys) == num_keys, keys) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Unexpected number of keys: expected={num_keys}, \ - received={len(result)} ({result}), table=\"{table_name}\"" + message = failure_message or ( + f"Unexpected number of keys: expected={num_keys}, received={len(result)} " + f'({result}), table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -333,7 +367,8 @@ def wait_for_matching_keys( self, table_name: str, expected_keys: List[str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> List[str]: """Wait for the specified keys to exist in the table. @@ -341,22 +376,27 @@ def wait_for_matching_keys( table_name: The name of the table from which to fetch the keys. expected_keys: The keys we expect to see in the table. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The keys stored in the table. If no keys are found, then an empty List is returned. """ - def __access_function(): + + def access_function(): keys = self.get_keys(table_name) return (all(key in keys for key in expected_keys), keys) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: - assert not polling_config.strict, \ - f"Expected keys not found: expected={expected_keys}, received={result}, \ - table=\"{table_name}\"" + message = failure_message or ( + f"Expected keys not found: expected={expected_keys}, received={result}, " + f'table="{table_name}"' + ) + assert not polling_config.strict, message return result @@ -364,7 +404,8 @@ def wait_for_deleted_keys( self, table_name: str, deleted_keys: List[str], - polling_config: PollingConfig = DEFAULT_POLLING_CONFIG + polling_config: PollingConfig = PollingConfig(), + failure_message: str = None, ) -> List[str]: """Wait for the specfied keys to no longer exist in the table. @@ -372,29 +413,34 @@ def wait_for_deleted_keys( table_name: The name of the table from which to fetch the keys. deleted_keys: The keys we expect to be removed from the table. polling_config: The parameters to use to poll the db. + failure_message: The message to print if the call times out. This will only take effect + if the PollingConfig is set to strict. Returns: The keys stored in the table. If no keys are found, then an empty List is returned. """ - def __access_function(): + + def access_function(): keys = self.get_keys(table_name) return (all(key not in keys for key in deleted_keys), keys) status, result = wait_for_result( - __access_function, - self._disable_strict_polling(polling_config)) + access_function, self._disable_strict_polling(polling_config) + ) if not status: expected = [key for key in result if key not in deleted_keys] - assert not polling_config.strict, \ - f"Unexpected keys found: expected={expected}, received={result}, \ - table=\"{table_name}\"" + message = failure_message or ( + f"Unexpected keys found: expected={expected}, received={result}, " + f'table="{table_name}"' + ) + assert not polling_config.strict, message return result @staticmethod def _disable_strict_polling(polling_config: PollingConfig) -> PollingConfig: - disabled_config = PollingConfig(polling_interval=polling_config.polling_interval, - timeout=polling_config.timeout, - strict=False) + disabled_config = PollingConfig( + polling_config.polling_interval, polling_config.timeout, False + ) return disabled_config diff --git a/tests/dvslib/dvs_vlan.py b/tests/dvslib/dvs_vlan.py index adca18b0e72b..ecb71ad98a0d 100644 --- a/tests/dvslib/dvs_vlan.py +++ b/tests/dvslib/dvs_vlan.py @@ -1,4 +1,4 @@ -from .dvs_database import DVSDatabase +from .dvs_common import PollingConfig class DVSVlan(object): def __init__(self, adb, cdb, sdb, cntrdb, appdb): @@ -49,7 +49,7 @@ def verify_vlan(self, vlan_oid, vlan_id): def get_and_verify_vlan_ids(self, expected_num, - polling_config=DVSDatabase.DEFAULT_POLLING_CONFIG): + polling_config=PollingConfig()): vlan_entries = self.asic_db.wait_for_n_keys("ASIC_STATE:SAI_OBJECT_TYPE_VLAN", expected_num + 1, polling_config) diff --git a/tests/test_fgnhg.py b/tests/test_fgnhg.py index 4c1a6a335c7a..89cd46eeabc2 100644 --- a/tests/test_fgnhg.py +++ b/tests/test_fgnhg.py @@ -5,7 +5,6 @@ import pytest from dvslib.dvs_common import wait_for_result -from dvslib.dvs_database import DVSDatabase from swsscommon import swsscommon IF_TB = 'INTERFACE' @@ -98,10 +97,7 @@ def _access_function(): if nhg_type != "SAI_NEXT_HOP_GROUP_TYPE_DYNAMIC_UNORDERED_ECMP": return false_ret return (True, nhgid) - status, result = wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) - if not status: - assert not polling_config.strict, \ - f"SAI_NEXT_HOP_GROUP_TYPE_DYNAMIC_UNORDERED_ECMP not found" + _, result = wait_for_result(_access_function, failure_message="SAI_NEXT_HOP_GROUP_TYPE_DYNAMIC_UNORDERED_ECMP not found") return result @@ -159,10 +155,9 @@ def _access_function(): ret = ret and (idx == 1) return (ret, nh_memb_count) - status, result = wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) - if not status: - assert not polling_config.strict, \ - f"Exact match not found: expected={nh_memb_exp_count}, received={result}" + status, result = wait_for_result(_access_function) + assert status, f"Exact match not found: expected={nh_memb_exp_count}, received={result}" + return result diff --git a/tests/test_nat.py b/tests/test_nat.py index b6492273ffff..6f8606d67bd8 100644 --- a/tests/test_nat.py +++ b/tests/test_nat.py @@ -1,7 +1,6 @@ import time from dvslib.dvs_common import wait_for_result -from dvslib.dvs_database import DVSDatabase class TestNat(object): @@ -316,7 +315,7 @@ def _check_conntrack_for_static_entry(): return (True, None) - wait_for_result(_check_conntrack_for_static_entry, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_check_conntrack_for_static_entry) # delete a static nat entry dvs.runcmd("config nat remove static basic 67.66.65.1 18.18.18.2") diff --git a/tests/test_route.py b/tests/test_route.py index b42ae6212f64..d7bef08cc12d 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -6,7 +6,6 @@ from swsscommon import swsscommon from dvslib.dvs_common import wait_for_result -from dvslib.dvs_database import DVSDatabase class TestRouteBase(object): def setup_db(self, dvs): @@ -61,7 +60,7 @@ def _access_function(): for route_entry in route_entries] return (all(destination in route_destinations for destination in destinations), None) - wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function) def check_route_entries_with_vrf(self, destinations, vrf_oids): def _access_function(): @@ -71,7 +70,7 @@ def _access_function(): return (all((destination, vrf_oid) in route_destination_vrf for destination, vrf_oid in zip(destinations, vrf_oids)), None) - wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function) def check_route_entries_nexthop(self, destinations, vrf_oids, nexthops): def _access_function_nexthop(): @@ -80,7 +79,7 @@ def _access_function_nexthop(): for key in nexthop_entries]) return (all(nexthop in nexthop_oids for nexthop in nexthops), nexthop_oids) - status, nexthop_oids = wait_for_result(_access_function_nexthop, DVSDatabase.DEFAULT_POLLING_CONFIG) + status, nexthop_oids = wait_for_result(_access_function_nexthop) def _access_function_route_nexthop(): route_entries = self.adb.get_keys("ASIC_STATE:SAI_OBJECT_TYPE_ROUTE_ENTRY") @@ -90,7 +89,7 @@ def _access_function_route_nexthop(): return (all(route_destination_nexthop.get((destination, vrf_oid)) == nexthop_oids.get(nexthop) for destination, vrf_oid, nexthop in zip(destinations, vrf_oids, nexthops)), None) - wait_for_result(_access_function_route_nexthop, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function_route_nexthop) def check_deleted_route_entries(self, destinations): def _access_function(): @@ -98,7 +97,7 @@ def _access_function(): route_destinations = [json.loads(route_entry)["dest"] for route_entry in route_entries] return (all(destination not in route_destinations for destination in destinations), None) - wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function) def clear_srv_config(self, dvs): dvs.servers[0].runcmd("ip address flush dev eth0") diff --git a/tests/test_sub_port_intf.py b/tests/test_sub_port_intf.py index ec1b88cb8cb5..8666296502f1 100644 --- a/tests/test_sub_port_intf.py +++ b/tests/test_sub_port_intf.py @@ -2,7 +2,6 @@ import time from dvslib.dvs_common import wait_for_result -from dvslib.dvs_database import DVSDatabase from swsscommon import swsscommon DEFAULT_MTU = "9100" @@ -138,7 +137,7 @@ def _access_function(): return (route_entry_found, raw_route_entry_key) - (route_entry_found, raw_route_entry_key) = wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + (route_entry_found, raw_route_entry_key) = wait_for_result(_access_function) fvs = self.asic_db.get_entry(ASIC_ROUTE_ENTRY_TABLE, raw_route_entry_key) @@ -166,7 +165,7 @@ def _access_function(): for raw_route_entry in raw_route_entries] return (all(dest in route_destinations for dest in expected_destinations), None) - wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function) def check_sub_port_intf_key_removal(self, db, table_name, key): db.wait_for_deleted_keys(table_name, [key]) @@ -179,7 +178,7 @@ def _access_function(): for raw_route_entry in raw_route_entries) return (status, None) - wait_for_result(_access_function, DVSDatabase.DEFAULT_POLLING_CONFIG) + wait_for_result(_access_function) def _test_sub_port_intf_creation(self, dvs, sub_port_intf_name): substrs = sub_port_intf_name.split(VLAN_SUB_INTERFACE_SEPARATOR)