Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't use astype for torch.Tensor #2242

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions bittensor/metagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import bittensor
from os import listdir
from os.path import join
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Union, Tuple, cast

from bittensor.chain_data import AxonInfo
from bittensor.utils.registration import torch, use_torch
from bittensor.utils import weight_utils

METAGRAPH_STATE_DICT_NDARRAY_KEYS = [
"version",
Expand Down Expand Up @@ -648,33 +649,40 @@ def _process_weights_or_bonds(
self.weights = self._process_weights_or_bonds(raw_weights_data, "weights")
"""
data_array = []
data_array: list[Union[NDArray[np.float32], "torch.Tensor"]] = []
for item in data:
if len(item) == 0:
if use_torch():
data_array.append(torch.zeros(len(self.neurons))) # type: ignore
data_array.append(torch.zeros(len(self.neurons)))
else:
data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) # type: ignore
data_array.append(np.zeros(len(self.neurons), dtype=np.float32))
else:
uids, values = zip(*item)
# TODO: Validate and test the conversion of uids and values to tensor
if attribute == "weights":
data_array.append(
bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor(
weight_utils.convert_weight_uids_and_vals_to_tensor(
len(self.neurons),
list(uids),
list(values), # type: ignore
list(values),
)
)
else:
data_array.append(
bittensor.utils.weight_utils.convert_bond_uids_and_vals_to_tensor( # type: ignore
len(self.neurons), list(uids), list(values)
).astype(np.float32)
da_item = weight_utils.convert_bond_uids_and_vals_to_tensor(
len(self.neurons), list(uids), list(values)
)
if use_torch():
data_array.append(cast("torch.LongTensor", da_item))
else:
data_array.append(
cast(NDArray[np.float32], da_item).astype(np.float32)
)
tensor_param: Union["torch.nn.Parameter", NDArray] = (
(
torch.nn.Parameter(torch.stack(data_array), requires_grad=False)
torch.nn.Parameter(
torch.stack(cast(list["torch.Tensor"], data_array)),
requires_grad=False,
)
if len(data_array)
else torch.nn.Parameter()
)
Expand Down Expand Up @@ -730,7 +738,7 @@ def _process_root_weights(
uids, values = zip(*item)
# TODO: Validate and test the conversion of uids and values to tensor
data_array.append(
bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore
weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore
n_subnets, list(uids), list(values), subnets
)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/test_metagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,37 @@ def test_process_weights_or_bonds(mock_environment):
# TODO: Add more checks to ensure the bonds have been processed correctly


def test_process_weights_or_bonds_torch(
mock_environment, force_legacy_torch_compat_api
):
_, neurons = mock_environment
metagraph = bittensor.metagraph(1, sync=False)
metagraph.neurons = neurons

# Test weights processing
weights = metagraph._process_weights_or_bonds(
data=[neuron.weights for neuron in neurons], attribute="weights"
)
assert weights.shape[0] == len(
neurons
) # Number of rows should be equal to number of neurons
assert weights.shape[1] == len(
neurons
) # Number of columns should be equal to number of neurons
# TODO: Add more checks to ensure the weights have been processed correctly

# Test bonds processing
bonds = metagraph._process_weights_or_bonds(
data=[neuron.bonds for neuron in neurons], attribute="bonds"
)
assert bonds.shape[0] == len(
neurons
) # Number of rows should be equal to number of neurons
assert bonds.shape[1] == len(
neurons
) # Number of columns should be equal to number of neurons


# Mocking the bittensor.subtensor class for testing purposes
@pytest.fixture
def mock_subtensor():
Expand Down