Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch bad endpoint protocol #1296

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions bittensor/_endpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
MAXPORT = 65535
MAXUID = 4294967295
ACCEPTABLE_IPTYPES = [4,6,0]
ACCEPTABLE_PROTOCOLS = [0] # TODO
ACCEPTABLE_PROTOCOLS = [0,4] # TODO
ENDPOINT_BUFFER_SIZE = 250

class endpoint:
Expand Down Expand Up @@ -107,7 +107,7 @@ def from_neuron( neuron: Union['bittensor.NeuronInfo', 'bittensor.NeuronInfoLite
def from_dict(endpoint_dict: dict) -> 'bittensor.Endpoint':
""" Return an endpoint with spec from dictionary
"""
endpoint.assert_format(
if not endpoint.assert_format(
version = endpoint_dict['version'],
uid = endpoint_dict['uid'],
hotkey = endpoint_dict['hotkey'],
Expand All @@ -116,7 +116,8 @@ def from_dict(endpoint_dict: dict) -> 'bittensor.Endpoint':
ip_type = endpoint_dict['ip_type'],
protocol = endpoint_dict['protocol'],
coldkey = endpoint_dict['coldkey']
)
):
raise ValueError('Invalid endpoint dict')
return endpoint_impl.Endpoint(
version = endpoint_dict['version'],
uid = endpoint_dict['uid'],
Expand Down Expand Up @@ -152,7 +153,10 @@ def from_tensor( tensor: torch.LongTensor) -> 'bittensor.Endpoint':
endpoint_bytes = bytearray( endpoint_list )
endpoint_string = endpoint_bytes.decode('utf-8')
endpoint_dict = json.loads( endpoint_string )
return endpoint.from_dict(endpoint_dict)
try:
return endpoint.from_dict(endpoint_dict)
except ValueError:
return endpoint.dummy()

@staticmethod
def dummy():
Expand All @@ -170,18 +174,21 @@ def assert_format(
coldkey:str
) -> bool:
""" Asserts that the endpoint has a valid format
Raises:
Multiple assertion errors.
"""
assert version >= 0, 'endpoint version must be positive. - got {}'.format(version)
assert version <= MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(version)
assert uid >= 0 and uid <= MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(uid)
assert len(ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(ip)
assert ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(ip_type)
assert port >= 0 and port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(port)
assert len(coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(coldkey)
assert len(hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(hotkey)
# TODO
assert protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(protocol)
try:
assert version >= 0, 'endpoint version must be positive. - got {}'.format(version)
assert version <= MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(version)
assert uid >= 0 and uid <= MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(uid)
assert len(ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(ip)
assert ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(ip_type)
assert port >= 0 and port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(port)
assert len(coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(coldkey)
assert len(hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(hotkey)
# TODO
assert protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(protocol)

return True
except AssertionError:
return False


27 changes: 15 additions & 12 deletions bittensor/_endpoint/endpoint_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,23 @@ def __init__( self, version: int, uid:int, hotkey:str, ip:str, ip_type:int, port
self.modality = modality


def assert_format( self ):
def assert_format( self ) -> bool:
""" Asserts that the endpoint has a valid format
Raises:
Multiple assertion errors.
"""
assert self.version > 0, 'endpoint version must be positive. - got {}'.format(self.version)
assert self.version < MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(self.version)
assert self.uid >= 0 and self.uid < MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(self.uid)
assert len(self.ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(self.ip)
assert self.ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(self.ip_type)
assert self.port > 0 and self.port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(self.port)
assert len(self.coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(self.coldkey)
assert len(self.hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(self.hotkey)
assert self.protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(self.protocol)
try:
assert self.version > 0, 'endpoint version must be positive. - got {}'.format(self.version)
assert self.version < MAX_VERSION, 'endpoint version must be less than 999. - got {}'.format(self.version)
assert self.uid >= 0 and self.uid < MAXUID, 'endpoint uid must positive and be less than u32 max: 4294967295. - got {}'.format(self.uid)
assert len(self.ip) < MAX_IP_LENGTH, 'endpoint ip string must have length less than 8*4. - got {}'.format(self.ip)
assert self.ip_type in ACCEPTABLE_IPTYPES, 'endpoint ip_type must be either 4 or 6.- got {}'.format(self.ip_type)
assert self.port > 0 and self.port < MAXPORT , 'port must be positive and less than 65535 - got {}'.format(self.port)
assert len(self.coldkey) == SS58_LENGTH, 'coldkey string must be length 48 - got {}'.format(self.coldkey)
assert len(self.hotkey) == SS58_LENGTH, 'hotkey string must be length 48 - got {}'.format(self.hotkey)
assert self.protocol in ACCEPTABLE_PROTOCOLS, 'protocol must be 0 (for now) - got {}'.format(self.protocol)

return True
except AssertionError as e:
return False

@property
def is_serving(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/bittensor_tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_create_endpoint():
protocol = 0,
)
assert endpoint.check_format() == True
endpoint.assert_format()
assert endpoint.assert_format()
assert endpoint.version == bittensor.__version_as_int__
assert endpoint.uid == 0
assert endpoint.ip == '0.0.0.0'
Expand Down