Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: Adds mypy to typecheck during circle ci checks #1705

Merged
merged 8 commits into from
Feb 21, 2024
6 changes: 6 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ jobs:
. env/bin/activate
python -m flake8 bittensor/ --count

- run:
name: Type check with mypy
command: |
. env/bin/activate
python -m mypy --ignore-missing-imports bittensor/

unit-tests-all-python-versions:
docker:
- image: cimg/python:3.10
Expand Down
222 changes: 137 additions & 85 deletions bittensor/axon.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions bittensor/chain_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,9 @@ class NeuronInfo:
validator_permit: bool
weights: List[List[int]]
bonds: List[List[int]]
prometheus_info: "PrometheusInfo"
axon_info: "AxonInfo"
pruning_score: int
prometheus_info: Optional["PrometheusInfo"] = None
axon_info: Optional[AxonInfo] = None
is_null: bool = False

@classmethod
Expand Down
87 changes: 39 additions & 48 deletions bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torch
import aiohttp
import bittensor
from fastapi import Response
from typing import Union, Optional, List, Union, AsyncGenerator, Any


Expand Down Expand Up @@ -96,7 +95,7 @@ class dendrite(torch.nn.Module):
"""

def __init__(
self, wallet: Optional[Union[bittensor.wallet, bittensor.keypair]] = None
self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None
):
"""
Initializes the Dendrite object, setting up essential properties.
Expand All @@ -121,7 +120,7 @@ def __init__(

self.synapse_history: list = []

self._session: aiohttp.ClientSession = None
self._session: Optional[aiohttp.ClientSession] = None

@property
async def session(self) -> aiohttp.ClientSession:
Expand Down Expand Up @@ -291,11 +290,8 @@ def _log_incoming_response(self, synapse):

def query(
self, *args, **kwargs
) -> Union[
bittensor.Synapse,
List[bittensor.Synapse],
bittensor.StreamingSynapse,
List[bittensor.StreamingSynapse],
) -> List[
Union[AsyncGenerator[Any, Any], bittensor.Synapse, bittensor.StreamingSynapse]
]:
"""
Makes a synchronous request to multiple target Axons and returns the server responses.
Expand All @@ -322,7 +318,7 @@ def query(
new_loop.close()
finally:
self.close_session()
return result
return result # type: ignore

async def forward(
self,
Expand All @@ -336,7 +332,7 @@ async def forward(
run_async: bool = True,
streaming: bool = False,
) -> List[
Union[AsyncGenerator[Any], bittensor.Synapse, bittensor.StreamingSynapse]
Union[AsyncGenerator[Any, Any], bittensor.Synapse, bittensor.StreamingSynapse]
]:
"""
Asynchronously sends requests to one or multiple Axons and collates their responses.
Expand Down Expand Up @@ -402,7 +398,9 @@ async def forward(

async def query_all_axons(
is_stream: bool,
) -> Union[AsyncGenerator[Any], bittensor.Synapse, bittensor.StreamingSynapse]:
) -> Union[
AsyncGenerator[Any, Any], bittensor.Synapse, bittensor.StreamingSynapse
]:
"""
Handles the processing of requests to all targeted axons, accommodating both streaming and non-streaming responses.

Expand All @@ -423,7 +421,7 @@ async def query_all_axons(
async def single_axon_response(
target_axon,
) -> Union[
AsyncGenerator[Any], bittensor.Synapse, bittensor.StreamingSynapse
AsyncGenerator[Any, Any], bittensor.Synapse, bittensor.StreamingSynapse
]:
"""
Manages the request and response process for a single axon, supporting both streaming and non-streaming modes.
Expand All @@ -446,15 +444,15 @@ async def single_axon_response(
# If in streaming mode, return the async_generator
return self.call_stream(
target_axon=target_axon,
synapse=synapse.copy(),
synapse=synapse.copy(), # type: ignore
timeout=timeout,
deserialize=deserialize,
)
else:
# If not in streaming mode, simply call the axon and get the response.
return await self.call(
target_axon=target_axon,
synapse=synapse.copy(),
synapse=synapse.copy(), # type: ignore
timeout=timeout,
deserialize=deserialize,
)
Expand All @@ -463,19 +461,16 @@ async def single_axon_response(
if not run_async:
return [
await single_axon_response(target_axon) for target_axon in axons
]
] # type: ignore
# If run_async flag is True, get responses concurrently using asyncio.gather().
return await asyncio.gather(
*(single_axon_response(target_axon) for target_axon in axons)
)
) # type: ignore

# Get responses for all axons.
responses = await query_all_axons(streaming)
# Return the single response if only one axon was targeted, else return all responses
if len(responses) == 1 and not is_list:
return responses[0]
else:
return responses
return responses[0] if len(responses) == 1 and not is_list else responses # type: ignore

async def call(
self,
Expand Down Expand Up @@ -533,7 +528,7 @@ async def call(
self.process_server_response(response, json_response, synapse)

# Set process time and log the response
synapse.dendrite.process_time = str(time.time() - start_time)
synapse.dendrite.process_time = str(time.time() - start_time) # type: ignore

except Exception as e:
self._handle_request_errors(synapse, request_name, e)
Expand All @@ -555,10 +550,10 @@ async def call(
async def call_stream(
self,
target_axon: Union[bittensor.AxonInfo, bittensor.axon],
synapse: bittensor.Synapse = bittensor.Synapse(),
synapse: bittensor.StreamingSynapse = bittensor.Synapse(), # type: ignore
timeout: float = 12.0,
deserialize: bool = True,
) -> AsyncGenerator[Any]:
) -> AsyncGenerator[Any, Any]:
"""
Sends a request to a specified Axon and yields streaming responses.

Expand Down Expand Up @@ -596,7 +591,7 @@ async def call_stream(
url = f"http://{endpoint}/{request_name}"

# Preprocess synapse for making a request
synapse = self.preprocess_synapse_for_request(target_axon, synapse, timeout)
synapse = self.preprocess_synapse_for_request(target_axon, synapse, timeout) # type: ignore

try:
# Log outgoing request
Expand All @@ -609,16 +604,16 @@ async def call_stream(
json=synapse.dict(),
timeout=timeout,
) as response:
# Use synapse subclass' process_streaming_response method to yield the response chunks
async for chunk in synapse.process_streaming_response(response):
# Use async for loop to yield the response chunks
async for chunk in response.content.iter_any():
yield chunk # Yield each chunk as it's processed
json_response = synapse.extract_response_json(response)

# Process the server response
self.process_server_response(response, json_response, synapse)

# Set process time and log the response
synapse.dendrite.process_time = str(time.time() - start_time)
synapse.dendrite.process_time = str(time.time() - start_time) # type: ignore

except Exception as e:
self._handle_request_errors(synapse, request_name, e)
Expand Down Expand Up @@ -657,26 +652,22 @@ def preprocess_synapse_for_request(
bittensor.Synapse: The preprocessed synapse.
"""
# Set the timeout for the synapse
synapse.timeout = str(timeout)
synapse.timeout = timeout

# Build the Dendrite headers using the local system's details
synapse.dendrite = bittensor.TerminalInfo(
**{
"ip": str(self.external_ip),
"version": str(bittensor.__version_as_int__),
"nonce": f"{time.monotonic_ns()}",
"uuid": str(self.uuid),
"hotkey": str(self.keypair.ss58_address),
}
ip=self.external_ip,
version=bittensor.__version_as_int__,
nonce=time.monotonic_ns(),
uuid=self.uuid,
hotkey=self.keypair.ss58_address,
)

# Build the Axon headers using the target axon's details
synapse.axon = bittensor.TerminalInfo(
**{
"ip": str(target_axon_info.ip),
"port": str(target_axon_info.port),
"hotkey": str(target_axon_info.hotkey),
}
ip=target_axon_info.ip,
port=target_axon_info.port,
hotkey=target_axon_info.hotkey,
)

# Sign the request using the dendrite, axon info, and the synapse body hash
Expand All @@ -687,7 +678,7 @@ def preprocess_synapse_for_request(

def process_server_response(
self,
server_response: Response,
server_response: aiohttp.ClientResponse,
json_response: dict,
local_synapse: bittensor.Synapse,
):
Expand Down Expand Up @@ -719,27 +710,27 @@ def process_server_response(
pass

# Extract server headers and overwrite None values in local synapse headers
server_headers = bittensor.Synapse.from_headers(server_response.headers)
server_headers = bittensor.Synapse.from_headers(server_response.headers) # type: ignore

# Merge dendrite headers
local_synapse.dendrite.__dict__.update(
{
**local_synapse.dendrite.dict(exclude_none=True),
**server_headers.dendrite.dict(exclude_none=True),
**local_synapse.dendrite.dict(exclude_none=True), # type: ignore
**server_headers.dendrite.dict(exclude_none=True), # type: ignore
}
)

# Merge axon headers
local_synapse.axon.__dict__.update(
{
**local_synapse.axon.dict(exclude_none=True),
**server_headers.axon.dict(exclude_none=True),
**local_synapse.axon.dict(exclude_none=True), # type: ignore
**server_headers.axon.dict(exclude_none=True), # type: ignore
}
)

# Update the status code and status message of the dendrite to match the axon
local_synapse.dendrite.status_code = local_synapse.axon.status_code
local_synapse.dendrite.status_message = local_synapse.axon.status_message
local_synapse.dendrite.status_code = local_synapse.axon.status_code # type: ignore
local_synapse.dendrite.status_message = local_synapse.axon.status_message # type: ignore

def __str__(self) -> str:
"""
Expand Down
6 changes: 6 additions & 0 deletions bittensor/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,9 @@ class InternalServerError(Exception):
r"""This exception is raised when the requested function fails on the server. Indicates a server error."""

pass


class SynapseDendriteNoneException(Exception):
    """Raised when a Synapse's dendrite attribute is unexpectedly ``None``."""

    def __init__(self, message="Synapse Dendrite is None"):
        # Keep the message accessible on the instance for callers that
        # inspect it directly, then delegate to the base Exception.
        self.message = message
        super().__init__(message)
6 changes: 3 additions & 3 deletions bittensor/stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from aiohttp import ClientResponse
import bittensor

from starlette.responses import StreamingResponse as _StreamingResponse
from starlette.responses import Response
from starlette.types import Send, Receive, Scope
from typing import Callable, Awaitable
from pydantic import BaseModel
Expand Down Expand Up @@ -98,7 +98,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
await self.stream_response(send)

@abstractmethod
async def process_streaming_response(self, response: Response):
async def process_streaming_response(self, response: ClientResponse):
"""
Abstract method that must be implemented by the subclass.
This method should provide logic to handle the streaming response, such as parsing and accumulating data.
Expand All @@ -111,7 +111,7 @@ async def process_streaming_response(self, response: Response):
...

@abstractmethod
def extract_response_json(self, response: Response) -> dict:
def extract_response_json(self, response: ClientResponse) -> dict:
"""
Abstract method that must be implemented by the subclass.
This method should provide logic to extract JSON data from the response, including headers and content.
Expand Down
18 changes: 18 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[mypy]
ignore_missing_imports = True
ignore_errors = True

[mypy-*.axon.*]
ignore_errors = False

[mypy-*.dendrite.*]
ignore_errors = False

; [mypy-*.metagraph.*] uncomment when mypy passes
; ignore_errors = False

; [mypy-*.subtensor.*] uncomment when mypy passes
; ignore_errors = False

; [mypy-*.synapse.*] uncomment when mypy passes
; ignore_errors = False
Loading