Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update authint version #1395

Merged
merged 6 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 3 additions & 33 deletions bittensor/_axon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dataclasses import dataclass
from substrateinterface import Keypair
import bittensor.utils.networking as net
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple

class axon:
""" Axon object for serving synapse receptors. """
Expand Down Expand Up @@ -60,7 +60,6 @@ def __init__(
max_workers: Optional[int] = None,
server: "grpc._server._Server" = None,
maximum_concurrent_rpcs: Optional[int] = None,
blacklist: Optional[Callable] = None,
) -> "bittensor.Axon":
r"""Creates a new bittensor.Axon object from passed arguments.
Args:
Expand All @@ -80,8 +79,6 @@ def __init__(
Used to create the threadpool if not passed, specifies the number of active threads servicing requests.
maximum_concurrent_rpcs (:type:`Optional[int]`, `optional`):
Maximum allowed concurrently processed RPCs.
blacklist (:obj:`Optional[callable]`, `optional`):
function to blacklist requests.
"""
self.metagraph = metagraph
self.wallet = wallet
Expand Down Expand Up @@ -111,17 +108,14 @@ def __init__(
self.external_ip = self.config.axon.external_ip if self.config.axon.external_ip != None else bittensor.utils.networking.get_external_ip()
self.external_port = self.config.axon.external_port if self.config.axon.external_port != None else self.config.axon.port
self.full_address = str(self.config.axon.ip) + ":" + str(self.config.axon.port)
self.blacklist = blacklist
self.started = False

# Build priority thread pool
self.priority_threadpool = bittensor.prioritythreadpool(config=self.config.axon)

# Build interceptor.
self.receiver_hotkey = self.wallet.hotkey.ss58_address
self.auth_interceptor = AuthInterceptor(
receiver_hotkey=self.receiver_hotkey, blacklist=self.blacklist
)
self.auth_interceptor = AuthInterceptor(receiver_hotkey=self.receiver_hotkey)

# Build grpc server
if server is None:
Expand Down Expand Up @@ -280,18 +274,14 @@ class AuthInterceptor(grpc.ServerInterceptor):
def __init__(
self,
receiver_hotkey: str,
blacklist: Callable = None,
):
r"""Creates a new server interceptor that authenticates incoming messages from passed arguments.
Args:
receiver_hotkey(str):
the SS58 address of the hotkey which should be targeted by RPCs
black_list (Function, `optional`):
black list function that prevents certain pubkeys from sending messages
"""
super().__init__()
self.nonces = {}
self.blacklist = blacklist
self.receiver_hotkey = receiver_hotkey


Expand All @@ -315,7 +305,7 @@ def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]
version = metadata.get('bittensor-version')
if signature is None:
raise Exception("Request signature missing")
if int(version) < 370:
if int(version) < 510:
raise Exception("Incorrect Version")
parts = self.parse_signature_v2(signature)
if parts is not None:
Expand Down Expand Up @@ -348,25 +338,8 @@ def check_signature(
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce

def black_list_checking(self, hotkey: str, method: str):
r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey"""
if self.blacklist is None:
return

request_type = {
"/Bittensor/Forward": bittensor.proto.RequestType.FORWARD,
"/Bittensor/Backward": bittensor.proto.RequestType.BACKWARD,
}.get(method)
if request_type is None:
raise Exception("Unknown request type")

failed, error_message = self.blacklist(hotkey, request_type)
if failed:
raise Exception(str(error_message))

def intercept_service(self, continuation, handler_call_details):
r"""Authentication between bittensor nodes. Intercepts messages and checks them"""
method = handler_call_details.method
metadata = dict(handler_call_details.invocation_metadata)

try:
Expand All @@ -382,9 +355,6 @@ def intercept_service(self, continuation, handler_call_details):
nonce, sender_hotkey, signature, receptor_uuid
)

# blacklist checking
self.black_list_checking(sender_hotkey, method)

return continuation(handler_call_details)

except Exception as e:
Expand Down
14 changes: 11 additions & 3 deletions bittensor/_synapse/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,22 @@ class SynapseCall( ABC ):
def __init__(
self,
synapse: 'bittensor.Synapse',
request_proto: object
request_proto: object,
context: grpc.ServicerContext,
):
metadata = dict(context.invocation_metadata())
(
_,
sender_hotkey,
_,
_,
) = synapse.axon.auth_interceptor.parse_signature(metadata)

self.completed = False
self.start_time = time.time()
self.timeout = request_proto.timeout
self.src_version = request_proto.version
self.src_hotkey = request_proto.hotkey
self.src_hotkey = sender_hotkey
self.dest_hotkey = synapse.axon.wallet.hotkey.ss58_address
self.dest_version = bittensor.__version_as_int__
self.return_code: bittensor.proto.ReturnCode = bittensor.proto.ReturnCode.Success
Expand Down Expand Up @@ -171,4 +180,3 @@ def apply( self, call: SynapseCall ) -> object:
call.end()
call.log_outbound()
return call._get_response_proto()

16 changes: 9 additions & 7 deletions bittensor/_synapse/text_prompting/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def __init__(
synapse: "bittensor.TextPromptingSynapseMulti",
request_proto: bittensor.proto.MultiForwardTextPromptingRequest,
multi_forward_callback: Callable,
context: grpc.ServicerContext
):
super().__init__( synapse = synapse, request_proto = request_proto )
super().__init__( synapse = synapse, request_proto = request_proto, context = context )
self.messages: List[ Dict[str, str] ] = request_proto.messages
self.formatted_messages = [ json.loads(message) for message in self.messages ]
self.multi_forward_callback = multi_forward_callback
Expand Down Expand Up @@ -67,8 +68,9 @@ def __init__(
synapse: "TextPromptingSynapse",
request_proto: bittensor.proto.ForwardTextPromptingRequest,
forward_callback: Callable,
context: grpc.ServicerContext
):
super().__init__( synapse = synapse, request_proto = request_proto )
super().__init__( synapse = synapse, request_proto = request_proto, context = context )
self.messages = request_proto.messages
self.formatted_messages = [ json.loads(message) for message in self.messages ]
self.forward_callback = forward_callback
Expand Down Expand Up @@ -99,8 +101,9 @@ def __init__(
synapse: "TextPromptingSynapse",
request_proto: bittensor.proto.BackwardTextPromptingRequest,
backward_callback: Callable,
context: grpc.ServicerContext
):
super().__init__( synapse = synapse, request_proto = request_proto )
super().__init__( synapse = synapse, request_proto = request_proto, context = context )
self.formatted_messages = [ message for message in request_proto.messages ]
self.formatted_rewards = torch.tensor( [ request_proto.rewards ], dtype = torch.float32 )
self.completion = request_proto.response
Expand Down Expand Up @@ -140,17 +143,16 @@ def multi_forward( self, messages: List[Dict[str, str]] ) -> List[ str ]: ...
def backward( self, messages: List[Dict[str, str]], response: str, rewards: torch.FloatTensor ) -> str: ...

def Forward( self, request: bittensor.proto.ForwardTextPromptingRequest, context: grpc.ServicerContext ) -> bittensor.proto.ForwardTextPromptingResponse:
call = SynapseForward( self, request, self.forward )
call = SynapseForward( self, request, self.forward, context )
bittensor.logging.trace( 'Forward: {} '.format( call ) )
return self.apply( call = call )

def MultiForward( self, request: bittensor.proto.MultiForwardTextPromptingRequest, context: grpc.ServicerContext ) -> bittensor.proto.MultiForwardTextPromptingResponse:
call = SynapseForwardMulti( self, request, self.multi_forward )
call = SynapseForwardMulti( self, request, self.multi_forward, context )
bittensor.logging.trace( 'MultiForward: {} '.format( call ) )
return self.apply( call = call )

def Backward( self, request: bittensor.proto.BackwardTextPromptingRequest, context: grpc.ServicerContext ) -> bittensor.proto.BackwardTextPromptingResponse:
call = SynapseBackward( self, request, self.backward )
call = SynapseBackward( self, request, self.backward, context )
bittensor.logging.trace( 'Backward: {}'.format( call ) )
return self.apply( call = call )

20 changes: 18 additions & 2 deletions tests/unit_tests/bittensor_tests/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,30 @@ def test_text_prompting_synapse_backward():
def test_text_prompting_synapse_blacklist():
synapse = get_synapse()
request = bittensor.proto.ForwardTextPromptingRequest()
call = bittensor._synapse.text_prompting.synapse.SynapseForward( synapse, request, synapse.forward )

# Mock the signature checking of the context.
context = MagicMock()
context.invocation_metadata.return_value = {}
synapse.axon = MagicMock()
synapse.axon.auth_interceptor = MagicMock()
synapse.axon.auth_interceptor.parse_signature.return_value = (None, None, "5CtstubuSoVLJGCXkiWRNKrrGg2DVBZ9qMs2qYTLsZR4q1Wg", None)

call = bittensor._synapse.text_prompting.synapse.SynapseForward( synapse, request, synapse.forward, context = context )
blacklist = synapse.blacklist( call )
assert blacklist == False

def test_text_prompting_synapse_priority():
synapse = get_synapse()
request = bittensor.proto.ForwardTextPromptingRequest()
call = bittensor._synapse.text_prompting.synapse.SynapseForward( synapse, request, synapse.forward )

# Mock the signature checking of the context.
context = MagicMock()
context.invocation_metadata.return_value = {}
synapse.axon = MagicMock()
synapse.axon.auth_interceptor = MagicMock()
synapse.axon.auth_interceptor.parse_signature.return_value = (None, None, "5CtstubuSoVLJGCXkiWRNKrrGg2DVBZ9qMs2qYTLsZR4q1Wg", None)

call = bittensor._synapse.text_prompting.synapse.SynapseForward( synapse, request, synapse.forward, context = context )
priority = synapse.priority( call )
assert priority == 0.0