From 672a53a79e8ec3db3206b4b68504e2636d0c1ae1 Mon Sep 17 00:00:00 2001
From: Benjamin Himes
Date: Fri, 24 May 2024 21:30:06 +0200
Subject: [PATCH] Added tests to cover all compatibility layer cases.

---
 bittensor/tensor.py                         | 125 +++++++++--------
 bittensor/utils/weight_utils.py             |   2 +-
 tests/unit_tests/test_chain_data.py         |  67 +++++++++
 tests/unit_tests/test_metagraph.py          |   2 +-
 tests/unit_tests/test_tensor.py             | 109 +++++++++++++++
 tests/unit_tests/utils/test_weight_utils.py | 147 ++++++++++++++++++++
 6 files changed, 392 insertions(+), 60 deletions(-)

diff --git a/bittensor/tensor.py b/bittensor/tensor.py
index e6bf6de3d1..ab46560d99 100644
--- a/bittensor/tensor.py
+++ b/bittensor/tensor.py
@@ -24,33 +24,54 @@
 from bittensor.utils.registration import torch, use_torch
 from pydantic import ConfigDict, BaseModel, Field, field_validator
 
-NUMPY_DTYPES = {
-    "float16": np.float16,
-    "float32": np.float32,
-    "float64": np.float64,
-    "uint8": np.uint8,
-    "int16": np.int16,
-    "int8": np.int8,
-    "int32": np.int32,
-    "int64": np.int64,
-    "bool": bool,
-}
-
-if use_torch():
-    TORCH_DTYPES = {
-        "torch.float16": torch.float16,
-        "torch.float32": torch.float32,
-        "torch.float64": torch.float64,
-        "torch.uint8": torch.uint8,
-        "torch.int16": torch.int16,
-        "torch.int8": torch.int8,
-        "torch.int32": torch.int32,
-        "torch.int64": torch.int64,
-        "torch.bool": torch.bool,
-    }
-
-
-def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str:
+
+class DTypes(dict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.torch: bool = False
+        self.update(
+            {
+                "float16": np.float16,
+                "float32": np.float32,
+                "float64": np.float64,
+                "uint8": np.uint8,
+                "int16": np.int16,
+                "int8": np.int8,
+                "int32": np.int32,
+                "int64": np.int64,
+                "bool": bool,
+            }
+        )
+
+    def __getitem__(self, key):
+        self._add_torch()
+        return super().__getitem__(key)
+
+    def __contains__(self, key):
+        self._add_torch()
+        return super().__contains__(key)
+
+    def _add_torch(self):
+        if self.torch is False:
+            torch_dtypes = {
+                "torch.float16": torch.float16,
+                "torch.float32": torch.float32,
+                "torch.float64": torch.float64,
+                "torch.uint8": torch.uint8,
+                "torch.int16": torch.int16,
+                "torch.int8": torch.int8,
+                "torch.int32": torch.int32,
+                "torch.int64": torch.int64,
+                "torch.bool": torch.bool,
+            }
+            self.update(torch_dtypes)
+            self.torch = True
+
+
+dtypes = DTypes()
+
+
+def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> Optional[str]:
     """
     Casts the raw value to a string representing the
     `numpy data type `_, or the
@@ -67,21 +88,16 @@ def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str:
     """
     if not raw:
         return None
-    if isinstance(raw, np.dtype):
-        return NUMPY_DTYPES[raw]
-    elif use_torch():
-        if isinstance(raw, torch.dtype):
-            return TORCH_DTYPES[raw]
+    if use_torch() and isinstance(raw, torch.dtype):
+        return str(raw)  # dtypes is keyed by strings; e.g. torch.float32 -> "torch.float32"
+    elif isinstance(raw, np.dtype):
+        return str(raw)  # e.g. np.dtype("float32") -> "float32"
     elif isinstance(raw, str):
         if use_torch():
-            assert (
-                raw in TORCH_DTYPES
-            ), f"{raw} not a valid torch type in dict {TORCH_DTYPES}"
+            assert raw in dtypes, f"{raw} not a valid torch type in dict {dtypes}"
             return raw
         else:
-            assert (
-                raw in NUMPY_DTYPES
-            ), f"{raw} not a valid numpy type in dict {NUMPY_DTYPES}"
+            assert raw in dtypes, f"{raw} not a valid numpy type in dict {dtypes}"
             return raw
     else:
         raise Exception(
@@ -89,7 +105,7 @@ def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str:
         )
 
 
-def cast_shape(raw: Union[None, List[int], str]) -> str:
+def cast_shape(raw: Union[None, List[int], str]) -> Optional[Union[str, list]]:
     """
     Casts the raw value to a string representing the tensor shape.
 
@@ -105,9 +121,7 @@ def cast_shape(raw: Union[None, List[int], str]) -> str:
     if not raw:
         return None
     elif isinstance(raw, list):
-        if len(raw) == 0:
-            return raw
-        elif isinstance(raw[0], int):
+        if len(raw) == 0 or isinstance(raw[0], int):
             return raw
         else:
             raise Exception(f"{raw} list elements are not of type int")
@@ -124,7 +138,7 @@ class tensor:
     def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]):
        if isinstance(tensor, list) or isinstance(tensor, np.ndarray):
             tensor = torch.tensor(tensor) if use_torch() else np.array(tensor)
-        return Tensor.serialize(tensor=tensor)
+        return Tensor.serialize(tensor_=tensor)
 
 
 class Tensor(BaseModel):
@@ -170,20 +184,20 @@ def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]:
             # Reshape does not work for (0) or [0]
             if not (len(shape) == 1 and shape[0] == 0):
                 torch_object = torch_object.reshape(shape)
-            return torch_object.type(TORCH_DTYPES[self.dtype])
+            return torch_object.type(dtypes[self.dtype])
         else:
             # Reshape does not work for (0) or [0]
             if not (len(shape) == 1 and shape[0] == 0):
                 numpy_object = numpy_object.reshape(shape)
-            return numpy_object.astype(NUMPY_DTYPES[self.dtype])
+            return numpy_object.astype(dtypes[self.dtype])
 
     @staticmethod
-    def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
+    def serialize(tensor_: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
         """
         Serializes the given tensor.
 
         Args:
-            tensor (np.array or torch.Tensor): The tensor to serialize.
+            tensor_ (np.array or torch.Tensor): The tensor to serialize.
 
         Returns:
             Tensor: The serialized tensor.
@@ -191,19 +205,14 @@ def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor":
         Raises:
             Exception: If the serialization process encounters an error.
         """
-        dtype = str(tensor.dtype)
-        shape = list(tensor.shape)
+        dtype = str(tensor_.dtype)
+        shape = list(tensor_.shape)
         if len(shape) == 0:
             shape = [0]
-        if use_torch():
-            torch_numpy = tensor.cpu().detach().numpy().copy()
-            data_buffer = base64.b64encode(
-                msgpack.packb(torch_numpy, default=msgpack_numpy.encode)
-            ).decode("utf-8")
-        else:
-            data_buffer = base64.b64encode(
-                msgpack.packb(tensor, default=msgpack_numpy.encode)
-            ).decode("utf-8")
+        tensor__ = tensor_.cpu().detach().numpy().copy() if use_torch() else tensor_
+        data_buffer = base64.b64encode(
+            msgpack.packb(tensor__, default=msgpack_numpy.encode)
+        ).decode("utf-8")
         return Tensor(buffer=data_buffer, shape=shape, dtype=dtype)
 
     # Represents the tensor buffer data.
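Note on the refactor above: the two eagerly-built module-level dicts (NUMPY_DTYPES, plus TORCH_DTYPES behind a use_torch() check) are replaced by a single `dtypes` mapping that adds the torch entries lazily on first lookup, so importing the module behaves the same regardless of which API mode is active. A minimal round-trip sketch of the compatibility layer in numpy mode (assumes `use_torch()` returns False, i.e. the legacy torch API is not forced; all names are taken from this patch):

    import numpy as np
    import bittensor

    # bittensor.tensor(...) serializes into the pydantic Tensor model: the payload
    # is msgpack-packed and base64-encoded, while dtype and shape are stored as a
    # plain string and list via cast_dtype and cast_shape.
    t = bittensor.tensor(np.array([1, 2, 3, 4], dtype=np.float32))
    assert t.dtype == "float32" and t.shape == [4]

    # deserialize() decodes the buffer and restores the dtype via dtypes["float32"].
    restored = t.deserialize()
    assert isinstance(restored, np.ndarray)
    assert np.array_equal(restored, np.array([1, 2, 3, 4], dtype=np.float32))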
diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py
index 9bd8606c9d..2810a9a0c1 100644
--- a/bittensor/utils/weight_utils.py
+++ b/bittensor/utils/weight_utils.py
@@ -55,7 +55,7 @@ def normalize_max_weight(
     if estimation.max() <= limit:
         return weights / weights.sum()
 
-    # Find the cumlative sum and sorted tensor
+    # Find the cumulative sum and sorted tensor
     cumsum = np.cumsum(estimation, 0)
 
     # Determine the index of cutoff
diff --git a/tests/unit_tests/test_chain_data.py b/tests/unit_tests/test_chain_data.py
index 2cc842d9a2..a6474bbee9 100644
--- a/tests/unit_tests/test_chain_data.py
+++ b/tests/unit_tests/test_chain_data.py
@@ -1,5 +1,6 @@
 import pytest
 import bittensor
+import torch
 from bittensor.chain_data import AxonInfo, ChainDataType, DelegateInfo, NeuronInfo
 
 SS58_FORMAT = bittensor.__ss58_format__
@@ -204,6 +205,36 @@ def test_to_parameter_dict(axon_info, test_case):
         assert result[key] == value, f"Test case: {test_case}"
 
 
+@pytest.mark.parametrize(
+    "axon_info, test_case",
+    [
+        (
+            AxonInfo(
+                version=1,
+                ip="127.0.0.1",
+                port=8080,
+                ip_type=4,
+                hotkey="hot",
+                coldkey="cold",
+            ),
+            "ID_to_parameter_dict",
+        ),
+    ],
+)
+def test_to_parameter_dict_torch(
+    axon_info,
+    test_case,
+    force_legacy_torch_compat_api,
+):
+    result = axon_info.to_parameter_dict()
+
+    # Assert
+    assert isinstance(result, torch.nn.ParameterDict)
+    for key, value in axon_info.__dict__.items():
+        assert key in result
+        assert result[key] == value, f"Test case: {test_case}"
+
+
 @pytest.mark.parametrize(
     "parameter_dict, expected, test_case",
     [
@@ -236,6 +267,42 @@ def test_from_parameter_dict(parameter_dict, expected, test_case):
     assert result == expected, f"Test case: {test_case}"
 
 
+@pytest.mark.parametrize(
+    "parameter_dict, expected, test_case",
+    [
+        (
+            torch.nn.ParameterDict(
+                {
+                    "version": 1,
+                    "ip": "127.0.0.1",
+                    "port": 8080,
+                    "ip_type": 4,
+                    "hotkey": "hot",
+                    "coldkey": "cold",
+                }
+            ),
+            AxonInfo(
+                version=1,
+                ip="127.0.0.1",
+                port=8080,
+                ip_type=4,
+                hotkey="hot",
+                coldkey="cold",
+            ),
+            "ID_from_parameter_dict",
+        ),
+    ],
+)
+def test_from_parameter_dict_torch(
+    parameter_dict, expected, test_case, force_legacy_torch_compat_api
+):
+    # Act
+    result = AxonInfo.from_parameter_dict(parameter_dict)
+
+    # Assert
+    assert result == expected, f"Test case: {test_case}"
+
+
 def create_neuron_info_decoded(
     hotkey,
     coldkey,
diff --git a/tests/unit_tests/test_metagraph.py b/tests/unit_tests/test_metagraph.py
index 38d2cf14cb..af0dbdba76 100644
--- a/tests/unit_tests/test_metagraph.py
+++ b/tests/unit_tests/test_metagraph.py
@@ -85,7 +85,7 @@ def test_set_metagraph_attributes(mock_environment):
             metagraph.consensus,
             np.array([neuron.consensus for neuron in neurons], dtype=np.float32),
         )
-        == True
+        is True
     )
 
     # Similarly for other attributes...
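Note on the new tests: every `*_torch` test in this patch requests a `force_legacy_torch_compat_api` fixture that is defined outside this diff (presumably in the suite's conftest.py). A hypothetical sketch of what such a fixture could look like, assuming the compatibility layer keys off a `USE_TORCH` environment variable consulted by `use_torch()` (the variable name and mechanism are assumptions, not shown in this diff):

    import pytest

    @pytest.fixture
    def force_legacy_torch_compat_api(monkeypatch):
        # Hypothetical: set the flag that use_torch() reads so the compatibility
        # layer takes its torch code paths; monkeypatch restores the environment
        # automatically after each test.
        monkeypatch.setenv("USE_TORCH", "1")
        yield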
diff --git a/tests/unit_tests/test_tensor.py b/tests/unit_tests/test_tensor.py
index 94d8f7cd52..9939b397e7 100644
--- a/tests/unit_tests/test_tensor.py
+++ b/tests/unit_tests/test_tensor.py
@@ -18,6 +18,7 @@
 import numpy as np
 import bittensor
 import numpy
+import torch
 
 
 # This is a fixture that creates an example tensor for testing
@@ -30,6 +31,15 @@ def example_tensor():
     return bittensor.tensor(data)
 
 
+@pytest.fixture
+def example_tensor_torch(force_legacy_torch_compat_api):
+    # Create a tensor from a list using PyTorch
+    data = torch.tensor([1, 2, 3, 4])
+
+    # Serialize the tensor into a Tensor instance and return it
+    return bittensor.tensor(data)
+
+
 def test_deserialize(example_tensor):
     # Deserialize the tensor from the Tensor instance
     tensor = example_tensor.deserialize()
@@ -39,6 +49,13 @@ def test_deserialize(example_tensor):
     assert tensor.tolist() == [1, 2, 3, 4]
 
 
+def test_deserialize_torch(example_tensor_torch, force_legacy_torch_compat_api):
+    tensor = example_tensor_torch.deserialize()
+    # Check that the result is a PyTorch tensor with the correct values
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.tolist() == [1, 2, 3, 4]
+
+
 def test_serialize(example_tensor):
     # Check that the serialized tensor is an instance of Tensor
     assert isinstance(example_tensor, bittensor.Tensor)
@@ -70,6 +87,37 @@ def test_serialize(example_tensor):
     assert example_tensor.shape == example_tensor.shape
 
 
+def test_serialize_torch(example_tensor_torch, force_legacy_torch_compat_api):
+    # Check that the serialized tensor is an instance of Tensor
+    assert isinstance(example_tensor_torch, bittensor.Tensor)
+
+    # Check that the Tensor instance has the correct buffer, dtype, and shape
+    assert example_tensor_torch.buffer == example_tensor_torch.buffer
+    assert example_tensor_torch.dtype == example_tensor_torch.dtype
+    assert example_tensor_torch.shape == example_tensor_torch.shape
+
+    assert isinstance(example_tensor_torch.tolist(), list)
+
+    # Check that the Tensor instance has the correct buffer, dtype, and shape
+    assert example_tensor_torch.buffer == example_tensor_torch.buffer
+    assert example_tensor_torch.dtype == example_tensor_torch.dtype
+    assert example_tensor_torch.shape == example_tensor_torch.shape
+
+    assert isinstance(example_tensor_torch.numpy(), numpy.ndarray)
+
+    # Check that the Tensor instance has the correct buffer, dtype, and shape
+    assert example_tensor_torch.buffer == example_tensor_torch.buffer
+    assert example_tensor_torch.dtype == example_tensor_torch.dtype
+    assert example_tensor_torch.shape == example_tensor_torch.shape
+
+    assert isinstance(example_tensor_torch.tensor(), torch.Tensor)
+
+    # Check that the Tensor instance has the correct buffer, dtype, and shape
+    assert example_tensor_torch.buffer == example_tensor_torch.buffer
+    assert example_tensor_torch.dtype == example_tensor_torch.dtype
+    assert example_tensor_torch.shape == example_tensor_torch.shape
+
+
 def test_buffer_field():
     # Create a Tensor instance with a specified buffer, dtype, and shape
     tensor = bittensor.Tensor(
@@ -80,6 +128,16 @@ def test_buffer_field():
         buffer="0x321e13edqwds231231231232131", dtype="float32", shape=[3, 3]
     )
 
     # Check that the buffer field matches the provided value
     assert tensor.buffer == "0x321e13edqwds231231231232131"
 
 
+def test_buffer_field_torch(force_legacy_torch_compat_api):
+    # Create a Tensor instance with a specified buffer, dtype, and shape
+    tensor = bittensor.Tensor(
+        buffer="0x321e13edqwds231231231232131", dtype="torch.float32", shape=[3, 3]
+    )
+
+    # Check that the buffer field matches the provided value
+    assert tensor.buffer == "0x321e13edqwds231231231232131"
+
+
 def test_dtype_field():
     # Create a Tensor instance with a specified buffer, dtype, and shape
     tensor = bittensor.Tensor(
@@ -90,6 +148,13 @@ def test_dtype_field():
     # Check that the dtype field matches the provided value
     assert tensor.dtype == "float32"
 
 
+def test_dtype_field_torch(force_legacy_torch_compat_api):
+    tensor = bittensor.Tensor(
+        buffer="0x321e13edqwds231231231232131", dtype="torch.float32", shape=[3, 3]
+    )
+    assert tensor.dtype == "torch.float32"
+
+
 def test_shape_field():
     # Create a Tensor instance with a specified buffer, dtype, and shape
     tensor = bittensor.Tensor(
@@ -100,6 +165,13 @@ def test_shape_field():
     # Check that the shape field matches the provided value
     assert tensor.shape == [3, 3]
 
 
+def test_shape_field_torch(force_legacy_torch_compat_api):
+    tensor = bittensor.Tensor(
+        buffer="0x321e13edqwds231231231232131", dtype="torch.float32", shape=[3, 3]
+    )
+    assert tensor.shape == [3, 3]
+
+
 def test_serialize_all_types():
     bittensor.tensor(np.array([1], dtype=np.float16))
     bittensor.tensor(np.array([1], dtype=np.float32))
@@ -110,6 +182,16 @@ def test_serialize_all_types():
     bittensor.tensor(np.array([1], dtype=np.int32))
     bittensor.tensor(np.array([1], dtype=np.int64))
     bittensor.tensor(np.array([1], dtype=bool))
 
 
+def test_serialize_all_types_torch(force_legacy_torch_compat_api):
+    bittensor.tensor(torch.tensor([1], dtype=torch.float16))
+    bittensor.tensor(torch.tensor([1], dtype=torch.float32))
+    bittensor.tensor(torch.tensor([1], dtype=torch.float64))
+    bittensor.tensor(torch.tensor([1], dtype=torch.uint8))
+    bittensor.tensor(torch.tensor([1], dtype=torch.int32))
+    bittensor.tensor(torch.tensor([1], dtype=torch.int64))
+    bittensor.tensor(torch.tensor([1], dtype=torch.bool))
+
+
 def test_serialize_all_types_equality():
     rng = np.random.default_rng()
 
@@ -132,3 +214,30 @@ def test_serialize_all_types_equality():
     tensor = rng.standard_normal((100,), dtype=np.float32) < 0.5
 
     assert np.all(bittensor.tensor(tensor).tensor() == tensor)
+
+
+def test_serialize_all_types_equality_torch(force_legacy_torch_compat_api):
+    torchtensor = torch.randn([100], dtype=torch.float16)
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randn([100], dtype=torch.float32)
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randn([100], dtype=torch.float64)
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randint(255, 256, (1000,), dtype=torch.uint8)
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randint(
+        2_147_483_646, 2_147_483_647, (1000,), dtype=torch.int32
+    )
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randint(
+        9_223_372_036_854_775_806, 9_223_372_036_854_775_807, (1000,), dtype=torch.int64
+    )
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
+
+    torchtensor = torch.randn([100], dtype=torch.float32) < 0.5
+    assert torch.all(bittensor.tensor(torchtensor).tensor() == torchtensor)
diff --git a/tests/unit_tests/utils/test_weight_utils.py b/tests/unit_tests/utils/test_weight_utils.py
index 0a42a9c9b3..edf334db50 100644
--- a/tests/unit_tests/utils/test_weight_utils.py
+++ b/tests/unit_tests/utils/test_weight_utils.py
@@ -56,6 +56,36 @@ def test_convert_weight_and_uids():
     weight_utils.convert_weights_and_uids_for_emit(uids, weights)
 
 
+def test_convert_weight_and_uids_torch(force_legacy_torch_compat_api):
+    uids = torch.tensor(list(range(10)))
+    weights = torch.rand(10)
+    weight_utils.convert_weights_and_uids_for_emit(uids, weights)
+
+    # min weight < 0
+    weights[5] = -1
+    with pytest.raises(ValueError) as pytest_wrapped_e:
+        weight_utils.convert_weights_and_uids_for_emit(uids, weights)
+    # min uid < 0
+    weights[5] = 0
+    uids[3] = -1
+    with pytest.raises(ValueError) as pytest_wrapped_e:
+        weight_utils.convert_weights_and_uids_for_emit(uids, weights)
+    # len(uids) != len(weights)
+    uids[3] = 3
+    with pytest.raises(ValueError) as pytest_wrapped_e:
+        weight_utils.convert_weights_and_uids_for_emit(uids, weights[1:])
+
+    # sum(weights) == 0
+    weights = torch.zeros(10)
+    weight_utils.convert_weights_and_uids_for_emit(uids, weights)
+
+    # test for overflow and underflow
+    for _ in range(5):
+        uids = torch.tensor(list(range(10)))
+        weights = torch.rand(10)
+        weight_utils.convert_weights_and_uids_for_emit(uids, weights)
+
+
 def test_normalize_with_max_weight():
     weights = np.random.rand(1000)
     wn = weight_utils.normalize_max_weight(weights, limit=0.01)
@@ -187,6 +217,37 @@ def test_convert_weight_uids_and_vals_to_tensor_happy_path(
     assert np.allclose(result, expected), f"Failed {test_id}"
 
 
+@pytest.mark.parametrize(
+    "test_id, n, uids, weights, subnets, expected",
+    [
+        (
+            "happy-path-1",
+            3,
+            [0, 1, 2],
+            [15, 5, 80],
+            [0, 1, 2],
+            torch.tensor([0.15, 0.05, 0.8]),
+        ),
+        (
+            "happy-path-2",
+            3,
+            [0, 2],
+            [300, 300],
+            [0, 1, 2],
+            torch.tensor([0.5, 0.0, 0.5]),
+        ),
+    ],
+)
+def test_convert_weight_uids_and_vals_to_tensor_happy_path_torch(
+    test_id, n, uids, weights, subnets, expected, force_legacy_torch_compat_api
+):
+    # Act
+    result = weight_utils.convert_weight_uids_and_vals_to_tensor(n, uids, weights)
+
+    # Assert
+    assert torch.allclose(result, expected), f"Failed {test_id}"
+
+
 @pytest.mark.parametrize(
     "test_id, n, uids, weights, expected",
     [
@@ -254,6 +315,39 @@ def test_convert_root_weight_uids_and_vals_to_tensor_happy_paths(
     assert np.allclose(result, expected, atol=1e-4), f"Failed {test_id}"
 
 
+@pytest.mark.parametrize(
+    "test_id, n, uids, weights, subnets, expected",
+    [
+        (
+            "edge-1",
+            1,
+            [0],
+            [0],
+            [0],
+            torch.tensor([0.0]),
+        ),  # Single neuron with zero weight
+        (
+            "edge-2",
+            2,
+            [0, 1],
+            [0, 0],
+            [0, 1],
+            torch.tensor([0.0, 0.0]),
+        ),  # All zero weights
+    ],
+)
+def test_convert_root_weight_uids_and_vals_to_tensor_edge_cases_torch(
+    test_id, n, uids, weights, subnets, expected, force_legacy_torch_compat_api
+):
+    # Act
+    result = weight_utils.convert_root_weight_uids_and_vals_to_tensor(
+        n, uids, weights, subnets
+    )
+
+    # Assert
+    assert torch.allclose(result, expected, atol=1e-4), f"Failed {test_id}"
+
+
 @pytest.mark.parametrize(
     "test_id, n, uids, weights, subnets, expected",
     [
@@ -333,6 +427,36 @@ def test_happy_path(test_id, n, uids, bonds, expected_output):
     assert np.array_equal(result, expected_output), f"Failed {test_id}"
 
 
+@pytest.mark.parametrize(
+    "test_id, n, uids, bonds, expected_output",
+    [
+        (
+            "happy-path-1",
+            5,
+            [1, 3, 4],
+            [10, 20, 30],
+            torch.tensor([0, 10, 0, 20, 30], dtype=torch.int64),
+        ),
+        (
+            "happy-path-2",
+            3,
+            [0, 1, 2],
+            [7, 8, 9],
+            torch.tensor([7, 8, 9], dtype=torch.int64),
+        ),
+        ("happy-path-3", 4, [2], [15], torch.tensor([0, 0, 15, 0], dtype=torch.int64)),
+    ],
+)
+def test_happy_path_torch(
+    test_id, n, uids, bonds, expected_output, force_legacy_torch_compat_api
+):
+    # Act
+    result = weight_utils.convert_bond_uids_and_vals_to_tensor(n, uids, bonds)
+
+    # Assert
+    assert torch.equal(result, expected_output), f"Failed {test_id}"
+
+
 @pytest.mark.parametrize(
     "test_id, n, uids, bonds, expected_output",
     [
@@ -354,6 +478,29 @@ def test_edge_cases(test_id, n, uids, bonds, expected_output):
     assert np.array_equal(result, expected_output), f"Failed {test_id}"
 
 
+@pytest.mark.parametrize(
+    "test_id, n, uids, bonds, expected_output",
+    [
+        ("edge-1", 1, [0], [0], torch.tensor([0], dtype=torch.int64)),  # Single element
+        (
+            "edge-2",
+            10,
+            [],
+            [],
+            torch.zeros(10, dtype=torch.int64),
+        ),  # Empty uids and bonds
+    ],
+)
+def test_edge_cases_torch(
+    test_id, n, uids, bonds, expected_output, force_legacy_torch_compat_api
+):
+    # Act
+    result = weight_utils.convert_bond_uids_and_vals_to_tensor(n, uids, bonds)
+
+    # Assert
+    assert torch.equal(result, expected_output), f"Failed {test_id}"
+
+
 @pytest.mark.parametrize(
     "test_id, n, uids, bonds, exception",
     [