diff --git a/bittensor/_endpoint/__init__.py b/bittensor/_endpoint/__init__.py index 2d21d04f4b..d3256d5f82 100644 --- a/bittensor/_endpoint/__init__.py +++ b/bittensor/_endpoint/__init__.py @@ -30,7 +30,7 @@ MAXPORT = 65535 MAXUID = 4294967295 ACCEPTABLE_IPTYPES = [4,6,0] -ACCEPTABLE_PROTOCOLS = [0] # TODO +ACCEPTABLE_PROTOCOLS = [0,4] # TODO ENDPOINT_BUFFER_SIZE = 250 class endpoint: @@ -107,7 +107,7 @@ def from_neuron( neuron: Union['bittensor.NeuronInfo', 'bittensor.NeuronInfoLite def from_dict(endpoint_dict: dict) -> 'bittensor.Endpoint': """ Return an endpoint with spec from dictionary """ - endpoint.assert_format( + if not endpoint.assert_format( version = endpoint_dict['version'], uid = endpoint_dict['uid'], hotkey = endpoint_dict['hotkey'], @@ -116,7 +116,8 @@ def from_dict(endpoint_dict: dict) -> 'bittensor.Endpoint': ip_type = endpoint_dict['ip_type'], protocol = endpoint_dict['protocol'], coldkey = endpoint_dict['coldkey'] - ) + ): + raise ValueError('Invalid endpoint dict') return endpoint_impl.Endpoint( version = endpoint_dict['version'], uid = endpoint_dict['uid'], @@ -152,7 +153,10 @@ def from_tensor( tensor: torch.LongTensor) -> 'bittensor.Endpoint': endpoint_bytes = bytearray( endpoint_list ) endpoint_string = endpoint_bytes.decode('utf-8') endpoint_dict = json.loads( endpoint_string ) - return endpoint.from_dict(endpoint_dict) + try: + return endpoint.from_dict(endpoint_dict) + except ValueError: + return endpoint.dummy() @staticmethod def dummy(): @@ -170,18 +174,21 @@ def assert_format( coldkey:str ) -> bool: """ Asserts that the endpoint has a valid format - Raises: - Multiple assertion errors. """ - assert version >= 0, 'endpoint version must be positive. - got {}'.format(version) - assert version <= MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(version) - assert uid >= 0 and uid <= MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(uid) - assert len(ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(ip) - assert ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(ip_type) - assert port >= 0 and port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(port) - assert len(coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(coldkey) - assert len(hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(hotkey) - # TODO - assert protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(protocol) + try: + assert version >= 0, 'endpoint version must be positive. - got {}'.format(version) + assert version <= MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(version) + assert uid >= 0 and uid <= MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(uid) + assert len(ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(ip) + assert ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(ip_type) + assert port >= 0 and port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(port) + assert len(coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(coldkey) + assert len(hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(hotkey) + # TODO + assert protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(protocol) + + return True + except AssertionError: + return False diff --git a/bittensor/_endpoint/endpoint_impl.py b/bittensor/_endpoint/endpoint_impl.py index 0505242856..e124205b89 100644 --- a/bittensor/_endpoint/endpoint_impl.py +++ b/bittensor/_endpoint/endpoint_impl.py @@ -48,20 +48,23 @@ def __init__( self, version: int, uid:int, hotkey:str, ip:str, ip_type:int, port self.modality = modality - def assert_format( self ): + def assert_format( self ) -> bool: """ Asserts that the endpoint has a valid format - Raises: - Multiple assertion errors. """ - assert self.version > 0, 'endpoint version must be positive. - got {}'.format(self.version) - assert self.version < MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(self.version) - assert self.uid >= 0 and self.uid < MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(self.uid) - assert len(self.ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(self.ip) - assert self.ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(self.ip_type) - assert self.port > 0 and self.port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(self.port) - assert len(self.coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(self.coldkey) - assert len(self.hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(self.hotkey) - assert self.protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(self.protocol) + try: + assert self.version > 0, 'endpoint version must be positive. - got {}'.format(self.version) + assert self.version < MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(self.version) + assert self.uid >= 0 and self.uid < MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(self.uid) + assert len(self.ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(self.ip) + assert self.ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(self.ip_type) + assert self.port > 0 and self.port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(self.port) + assert len(self.coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(self.coldkey) + assert len(self.hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(self.hotkey) + assert self.protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(self.protocol) + + return True + except AssertionError as e: + return False @property def is_serving(self) -> bool: diff --git a/tests/unit_tests/bittensor_tests/test_endpoint.py b/tests/unit_tests/bittensor_tests/test_endpoint.py index cf40ddadaf..3c3d7fa05a 100644 --- a/tests/unit_tests/bittensor_tests/test_endpoint.py +++ b/tests/unit_tests/bittensor_tests/test_endpoint.py @@ -38,7 +38,7 @@ def test_create_endpoint(): protocol = 0, ) assert endpoint.check_format() == True - endpoint.assert_format() + assert endpoint.assert_format() assert endpoint.version == bittensor.__version_as_int__ assert endpoint.uid == 0 assert endpoint.ip == '0.0.0.0'