
Commit e8ab3b9
Merge pull request #2311 from opentensor/feat/roman/remove-class-tensor
remove unused code (tensor.py-> class tensor), remove old tests, add new tests
roman-opentensor authored Sep 17, 2024
2 parents 29e7d82 + 10dee4a commit e8ab3b9
Showing 3 changed files with 32 additions and 39 deletions.
9 changes: 1 addition & 8 deletions bittensor/core/tensor.py
@@ -133,13 +133,6 @@ def cast_shape(raw: Union[None, List[int], str]) -> Optional[Union[str, list]]:
         )
 
 
-class tensor:
-    def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]):
-        if isinstance(tensor, list) or isinstance(tensor, np.ndarray):
-            tensor = torch.tensor(tensor) if use_torch() else np.array(tensor)
-        return Tensor.serialize(tensor_=tensor)
-
-
 class Tensor(BaseModel):
     """
     Represents a Tensor object.
@@ -158,7 +151,7 @@ def tensor(self) -> Union[np.ndarray, "torch.Tensor"]:
     def tolist(self) -> List[object]:
         return self.deserialize().tolist()
 
-    def numpy(self) -> "numpy.ndarray":
+    def numpy(self) -> "np.ndarray":
         return (
             self.deserialize().detach().numpy() if use_torch() else self.deserialize()
         )
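For callers of this module, the net effect is that serialization now goes through Tensor.serialize directly instead of the removed lowercase tensor factory. The following is a minimal migration sketch, not code from this commit; it assumes Tensor.serialize accepts the array positionally, which is how the updated tests call it.

# Migration sketch (illustrative, not part of this commit): callers that built a
# serialized tensor through the removed lowercase `tensor` factory can call
# Tensor.serialize directly, as the updated tests do.
import numpy as np

from bittensor.core.tensor import Tensor

data = np.array([1, 2, 3, 4])

# Before (removed by this PR):
#   serialized = tensor(data)            # the factory simply wrapped Tensor.serialize
# After:
serialized = Tensor.serialize(data)      # returns a Tensor (pydantic BaseModel)

# Round-trip back to the original values; deserialize() returns a numpy array
# (or a torch.Tensor when the legacy torch-compatible API is enabled).
restored = serialized.deserialize()
assert np.all(restored == data)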
2 changes: 1 addition & 1 deletion bittensor/utils/deprecated.py
@@ -96,7 +96,7 @@
 from bittensor.core.stream import StreamingSynapse  # noqa: F401
 from bittensor.core.subtensor import Subtensor
 from bittensor.core.synapse import TerminalInfo, Synapse  # noqa: F401
-from bittensor.core.tensor import tensor, Tensor  # noqa: F401
+from bittensor.core.tensor import Tensor  # noqa: F401
 from bittensor.core.threadpool import (  # noqa: F401
     PriorityThreadPoolExecutor as PriorityThreadPoolExecutor,
 )
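Because bittensor/utils/deprecated.py re-exports these names (the # noqa: F401 imports), the lowercase tensor also disappears from wherever that legacy shim is surfaced. Below is a hypothetical before/after for code going through the package namespace; it assumes the shim exposes these names on bittensor itself, which is not shown in this diff.

# Hypothetical legacy-namespace usage; assumes the deprecated shim re-exports
# bittensor.core.tensor names at the package level (an assumption, not shown here).
import numpy as np
import bittensor as bt

data = np.array([1, 2, 3])

# Before this commit the lowercase factory was re-exported:
#   serialized = bt.tensor(data)          # removed by this PR
# After this commit only the Tensor model remains:
serialized = bt.Tensor.serialize(data)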
60 changes: 30 additions & 30 deletions tests/unit_tests/test_tensor.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 
-from bittensor.core.tensor import tensor as tensor_class, Tensor
+from bittensor.core.tensor import Tensor
 
 
 # This is a fixture that creates an example tensor for testing
@@ -30,7 +30,7 @@ def example_tensor():
     data = np.array([1, 2, 3, 4])
 
     # Serialize the tensor into a Tensor instance and return it
-    return tensor_class(data)
+    return Tensor.serialize(data)
 
 
 @pytest.fixture
@@ -39,7 +39,7 @@ def example_tensor_torch(force_legacy_torch_compatible_api):
     data = torch.tensor([1, 2, 3, 4])
 
     # Serialize the tensor into a Tensor instance and return it
-    return tensor_class(data)
+    return Tensor.serialize(data)
 
 
 def test_deserialize(example_tensor):
@@ -175,71 +175,71 @@ def test_shape_field_torch(force_legacy_torch_compatible_api):
 
 
 def test_serialize_all_types():
-    tensor_class(np.array([1], dtype=np.float16))
-    tensor_class(np.array([1], dtype=np.float32))
-    tensor_class(np.array([1], dtype=np.float64))
-    tensor_class(np.array([1], dtype=np.uint8))
-    tensor_class(np.array([1], dtype=np.int32))
-    tensor_class(np.array([1], dtype=np.int64))
-    tensor_class(np.array([1], dtype=bool))
+    Tensor.serialize(np.array([1], dtype=np.float16))
+    Tensor.serialize(np.array([1], dtype=np.float32))
+    Tensor.serialize(np.array([1], dtype=np.float64))
+    Tensor.serialize(np.array([1], dtype=np.uint8))
+    Tensor.serialize(np.array([1], dtype=np.int32))
+    Tensor.serialize(np.array([1], dtype=np.int64))
+    Tensor.serialize(np.array([1], dtype=bool))
 
 
 def test_serialize_all_types_torch(force_legacy_torch_compatible_api):
-    tensor_class(torch.tensor([1], dtype=torch.float16))
-    tensor_class(torch.tensor([1], dtype=torch.float32))
-    tensor_class(torch.tensor([1], dtype=torch.float64))
-    tensor_class(torch.tensor([1], dtype=torch.uint8))
-    tensor_class(torch.tensor([1], dtype=torch.int32))
-    tensor_class(torch.tensor([1], dtype=torch.int64))
-    tensor_class(torch.tensor([1], dtype=torch.bool))
+    Tensor.serialize(torch.tensor([1], dtype=torch.float16))
+    Tensor.serialize(torch.tensor([1], dtype=torch.float32))
+    Tensor.serialize(torch.tensor([1], dtype=torch.float64))
+    Tensor.serialize(torch.tensor([1], dtype=torch.uint8))
+    Tensor.serialize(torch.tensor([1], dtype=torch.int32))
+    Tensor.serialize(torch.tensor([1], dtype=torch.int64))
+    Tensor.serialize(torch.tensor([1], dtype=torch.bool))
 
 
 def test_serialize_all_types_equality():
     rng = np.random.default_rng()
 
     tensor = rng.standard_normal((100,), dtype=np.float32)
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
     tensor = rng.standard_normal((100,), dtype=np.float64)
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
     tensor = np.random.randint(255, 256, (1000,), dtype=np.uint8)
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
     tensor = np.random.randint(2_147_483_646, 2_147_483_647, (1000,), dtype=np.int32)
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
     tensor = np.random.randint(
         9_223_372_036_854_775_806, 9_223_372_036_854_775_807, (1000,), dtype=np.int64
     )
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
     tensor = rng.standard_normal((100,), dtype=np.float32) < 0.5
-    assert np.all(tensor_class(tensor).tensor() == tensor)
+    assert np.all(Tensor.serialize(tensor).tensor() == tensor)
 
 
 def test_serialize_all_types_equality_torch(force_legacy_torch_compatible_api):
     torchtensor = torch.randn([100], dtype=torch.float16)
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randn([100], dtype=torch.float32)
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randn([100], dtype=torch.float64)
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randint(255, 256, (1000,), dtype=torch.uint8)
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randint(
         2_147_483_646, 2_147_483_647, (1000,), dtype=torch.int32
     )
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randint(
         9_223_372_036_854_775_806, 9_223_372_036_854_775_807, (1000,), dtype=torch.int64
     )
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
 
     torchtensor = torch.randn([100], dtype=torch.float32) < 0.5
-    assert torch.all(tensor_class(torchtensor).tensor() == torchtensor)
+    assert torch.all(Tensor.serialize(torchtensor).tensor() == torchtensor)
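The rewritten tests construct instances with Tensor.serialize(...) and compare values through .tensor(). For orientation, here is a short sketch, not part of the test suite, that exercises only the accessors visible in this diff and assumes the default numpy code path (use_torch() disabled).

# Quick round-trip using only the methods shown in this diff: serialize, tensor,
# tolist, and numpy. Assumes the numpy code path (legacy torch API not enabled).
import numpy as np

from bittensor.core.tensor import Tensor

t = Tensor.serialize(np.array([1, 2, 3, 4], dtype=np.int64))

assert np.all(t.tensor() == np.array([1, 2, 3, 4]))  # values round-trip intact
assert t.tolist() == [1, 2, 3, 4]                     # plain Python list
assert isinstance(t.numpy(), np.ndarray)              # numpy array in numpy mode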
