Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Revolution aio #1488

Merged
merged 24 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
af49ddf
aiohttp
ifrit98 Aug 8, 2023
d66ce18
swap req for httpx/aiohttp
ifrit98 Aug 8, 2023
adf863a
(WIP)
ifrit98 Aug 9, 2023
eee1a07
add processing on dendrite side that calls stream synapse
ifrit98 Aug 17, 2023
0d1dae5
add example that works, fix var name
ifrit98 Aug 17, 2023
56ce93f
remove print statements
ifrit98 Aug 17, 2023
4df2486
merge revolution threadpool updates into aio
ifrit98 Aug 17, 2023
9690a41
integrate streaming into bittensor
ifrit98 Aug 23, 2023
fde18d0
Merge branch 'revolution' into revolution_aio
ifrit98 Aug 23, 2023
fe87e39
run black and remove examples
ifrit98 Aug 23, 2023
088815f
drop support for python<3.9
ifrit98 Aug 23, 2023
36f6425
merge revolution back in with fixed tests
ifrit98 Aug 23, 2023
c37eba6
Merge branch 'revolution' into revolution_aio
ifrit98 Aug 24, 2023
89ef61a
defer to default deserialize() instead of abstractmethod, make loggin…
ifrit98 Aug 25, 2023
a7bdc79
Merge branch 'revolution' into revolution_aio
ifrit98 Aug 25, 2023
6c5051a
Merge branch 'revolution' into revolution_aio
ifrit98 Aug 29, 2023
5499b82
Merge branch 'revolution-more-info' into revolution_aio
ifrit98 Aug 31, 2023
6015333
add client pooling for efficiency
ifrit98 Aug 31, 2023
dcd1a1d
add annotations so we don't have to use string type-hints, run black
ifrit98 Aug 31, 2023
b21f0d2
Aio merge master (#1507)
ifrit98 Sep 1, 2023
4e741a0
Merge branch 'revolution' into revolution_aio
ifrit98 Sep 1, 2023
7d054fb
Merge branch 'revolution' into revolution_aio
ifrit98 Sep 1, 2023
3fef8f1
merge master and update reqs
ifrit98 Sep 1, 2023
51ab1d8
remove trace calls
ifrit98 Sep 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ $ pip3 install bittensor
```
3. From source:
```bash
$ git clone --recurse-submodules https://github.com/opentensor/bittensor.git
$ git clone https://github.com/opentensor/bittensor.git
$ python3 -m pip install -e bittensor/
```
4. Using Conda (recommended for **Apple M1**):
Expand All @@ -41,16 +41,6 @@ $ conda env create -f ~/.bittensor/bittensor/scripts/environments/apple_m1_envir
$ conda activate bittensor
```

To sync the submodules bittensor-wallet and bittensor-config:
```bash
cd bittensor/

git submodule sync && git submodule update --init

# Reinstall with updated submodules
python3 -m pip install -e .
```

To test your installation, type:
```bash
$ btcli --help
Expand Down
1 change: 1 addition & 0 deletions bittensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def debug(on: bool = True):
from .threadpool import PriorityThreadPoolExecutor as PriorityThreadPoolExecutor

from .synapse import *
from .stream import *
from .tensor import *
from .axon import axon as axon
from .dendrite import dendrite as dendrite
Expand Down
74 changes: 51 additions & 23 deletions bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from __future__ import annotations

import asyncio
import uuid
import time
import torch
import httpx
import aiohttp
import bittensor as bt
from typing import Union, Optional, List

Expand Down Expand Up @@ -53,7 +55,7 @@ class dendrite(torch.nn.Module):
>>> d( bt.axon(), bt.Synapse )
"""

def __init__(self, wallet: Optional[Union["bt.wallet", "bt.keypair"]] = None):
def __init__(self, wallet: Optional[Union[bt.wallet, bt.keypair]] = None):
"""
Initializes the Dendrite object, setting up essential properties.

Expand All @@ -68,9 +70,6 @@ def __init__(self, wallet: Optional[Union["bt.wallet", "bt.keypair"]] = None):
# Unique identifier for the instance
self.uuid = str(uuid.uuid1())

# HTTP client for making requests
self.client = httpx.AsyncClient()

# Get the external IP
self.external_ip = bt.utils.networking.get_external_ip()

Expand All @@ -81,6 +80,19 @@ def __init__(self, wallet: Optional[Union["bt.wallet", "bt.keypair"]] = None):

self.synapse_history: list = []

self._session: aiohttp.ClientSession = None

@property
async def session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp client session, creating it lazily.

    The session is opened on first access and then reused across
    requests so connection pooling is effective. Callers must await
    this property: ``await self.session``.

    Returns:
        aiohttp.ClientSession: The live, pooled HTTP client session.
    """
    existing = self._session
    if existing is not None:
        return existing
    # First access: open the pooled HTTP client session.
    self._session = aiohttp.ClientSession()
    return self._session

async def close_session(self) -> None:
    """Close the pooled aiohttp session, if one was ever opened.

    After a successful close the cached session is cleared so a later
    access of ``self.session`` opens a fresh one. A no-op when no
    session exists.
    """
    if not self._session:
        return
    # Close first, then clear: if close() raises, the session reference
    # is retained (matching the original error behavior).
    await self._session.close()
    self._session = None

def query(self, *args, **kwargs):
"""
Makes a synchronous request to multiple target Axons and returns the server responses.
Expand Down Expand Up @@ -108,9 +120,7 @@ def query(self, *args, **kwargs):

async def forward(
self,
axons: Union[
List[Union["bt.AxonInfo", "bt.axon"]], Union["bt.AxonInfo", "bt.axon"]
],
axons: Union[List[Union[bt.AxonInfo, bt.axon]], Union[bt.AxonInfo, bt.axon]],
synapse: bt.Synapse = bt.Synapse(),
timeout: float = 12,
deserialize: bool = True,
Expand Down Expand Up @@ -183,7 +193,7 @@ async def query_all_axons():

async def call(
self,
target_axon: Union["bt.AxonInfo", "bt.axon"],
target_axon: Union[bt.AxonInfo, bt.axon],
synapse: bt.Synapse = bt.Synapse(),
timeout: float = 12.0,
deserialize: bool = True,
Expand Down Expand Up @@ -229,25 +239,40 @@ async def call(
)

# Make the HTTP POST request
json_response = await self.client.post(
url, headers=synapse.to_headers(), json=synapse.dict(), timeout=timeout
)

# Process the server response
self.process_server_response(json_response, synapse)
async with (await self.session).post(
url,
headers=synapse.to_headers(),
json=synapse.dict(),
timeout=timeout,
) as response:
if (
response.headers.get("Content-Type", "").lower()
== "text/event-stream".lower()
): # identify streaming response
bt.logging.trace("Streaming response detected.")
await synapse.process_streaming_response(
response
) # process the entire streaming response
json_response = synapse.extract_response_json(response)
else:
bt.logging.trace("Non-streaming response detected.")
json_response = await response.json()

# Process the server response
self.process_server_response(response, json_response, synapse)

# Set process time and log the response
synapse.dendrite.process_time = str(time.time() - start_time)
bt.logging.debug(
f"dendrite | <-- | {synapse.get_total_size()} B | {synapse.name} | {synapse.axon.hotkey} | {synapse.axon.ip}:{str(synapse.axon.port)} | {synapse.axon.status_code} | {synapse.axon.status_message}"
)

except httpx.ConnectError as e:
except aiohttp.ClientConnectorError as e:
synapse.dendrite.status_code = "503"
synapse.dendrite.status_message = f"Service at {synapse.axon.ip}:{str(synapse.axon.port)}/{request_name} unavailable."

except httpx.TimeoutException as e:
synapse.dendrite.status_code = "406"
except asyncio.TimeoutError as e:
synapse.dendrite.status_code = "408"
synapse.dendrite.status_message = f"Timedout after {timeout} seconds."

except Exception as e:
Expand All @@ -272,7 +297,7 @@ async def call(

def preprocess_synapse_for_request(
self,
target_axon_info: "bt.AxonInfo",
target_axon_info: bt.AxonInfo,
synapse: bt.Synapse,
timeout: float = 12.0,
) -> bt.Synapse:
Expand Down Expand Up @@ -320,13 +345,16 @@ def preprocess_synapse_for_request(

return synapse

def process_server_response(self, server_response, local_synapse: bt.Synapse):
def process_server_response(
self, server_response, json_response, local_synapse: bt.Synapse
):
"""
Processes the server response, updates the local synapse state with the
server's state and merges headers set by the server.

Args:
server_response (object): The response object from the server.
server_response (object): The aiohttp response object from the server.
json_response (dict): The parsed JSON response from the server.
local_synapse (bt.Synapse): The local synapse object to be updated.

Raises:
Expand All @@ -335,11 +363,11 @@ def process_server_response(self, server_response, local_synapse: bt.Synapse):
bt.logging.trace("Postprocess server response")

# Check if the server responded with a successful status code
if server_response.status_code == 200:
if server_response.status == 200:
# If the response is successful, overwrite local synapse state with
# server's state only if the protocol allows mutation. To prevent overwrites,
# the protocol must set allow_mutation = False
server_synapse = local_synapse.__class__(**server_response.json())
server_synapse = local_synapse.__class__(**json_response)
for key in local_synapse.dict().keys():
try:
# Set the attribute in the local synapse from the corresponding
Expand Down
147 changes: 147 additions & 0 deletions bittensor/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import bittensor

from starlette.responses import StreamingResponse as _StreamingResponse
from starlette.types import Send
from typing import Callable, Awaitable, List
from pydantic import BaseModel
from abc import ABC, abstractmethod


class BTStreamingResponseModel(BaseModel):
    """Pydantic wrapper that validates the token-streamer callable.

    StreamingSynapse uses this model to carry the streamer into a
    BTStreamingResponse. The streamer is a callable taking the ASGI
    ``send`` function and returning an awaitable; it is responsible for
    producing the body of the streaming response (typically by emitting
    tokens to the client as they become available).

    Wrapping the callable in a model enforces the expected signature at
    construction time and gives BTStreamingResponse a single, typed
    hand-off point.

    Attributes:
        token_streamer: Callable[[Send], Awaitable[None]]
            Async producer of the response content; invoked with the
            ASGI server's ``send`` callable.
    """

    # Validated at model construction; BTStreamingResponse reads this field.
    token_streamer: Callable[[Send], Awaitable[None]]


class StreamingSynapse(bittensor.Synapse, ABC):
    """
    The StreamingSynapse class is designed to be subclassed for handling streaming responses in the Bittensor network.
    It provides abstract methods that must be implemented by the subclass to deserialize, process streaming responses,
    and extract JSON data. It also includes a method to create a streaming response object.
    """

    class Config:
        # Re-validate fields whenever they are assigned, keeping synapse
        # state consistent after mutation (pydantic behavior).
        validate_assignment = True

    class BTStreamingResponse(_StreamingResponse):
        """
        BTStreamingResponse is a specialized subclass of the Starlette StreamingResponse designed to handle the streaming
        of tokens within the Bittensor network. It is used internally by the StreamingSynapse class to manage the response
        streaming process, including sending headers and calling the token streamer provided by the subclass.

        This class is not intended to be directly instantiated or modified by developers subclassing StreamingSynapse.
        Instead, it is used by the create_streaming_response method to create a response object based on the token streamer
        provided by the subclass.
        """

        def __init__(self, model: BTStreamingResponseModel, **kwargs) -> None:
            """
            Initializes the BTStreamingResponse with the given token streamer model.

            Args:
                model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible
                    for generating the content of the response.
                **kwargs: Additional keyword arguments passed to the parent StreamingResponse class.
            """
            # The real body is produced by the token streamer inside
            # stream_response, so the parent is given an empty iterator
            # as its (unused) content.
            super().__init__(content=iter(()), **kwargs)
            self.token_streamer = model.token_streamer

        async def stream_response(self, send: Send) -> None:
            """
            Asynchronously streams the response by sending headers and calling the token streamer.

            This method is responsible for initiating the response by sending the appropriate headers, including the
            content type for event-streaming. It then calls the token streamer to generate the content and sends the
            response body to the client.

            Args:
                send: A callable to send the response, provided by the ASGI server.
            """
            bittensor.logging.trace("Streaming response.")

            # Advertise server-sent-event framing first (the dendrite keys
            # off this Content-Type to detect a streaming response), then
            # append any user-supplied headers.
            headers = [(b"content-type", b"text/event-stream")] + self.raw_headers

            # ASGI protocol: exactly one http.response.start, then body events.
            await send(
                {"type": "http.response.start", "status": 200, "headers": headers}
            )

            # Delegate body generation to the subclass-provided streamer.
            await self.token_streamer(send)

            # Final empty body chunk with more_body=False terminates the response.
            await send({"type": "http.response.body", "body": b"", "more_body": False})

        async def __call__(self, scope, receive, send):
            """
            Asynchronously calls the stream_response method, allowing the BTStreamingResponse object to be used as an ASGI
            application.

            This method is part of the ASGI interface and is called by the ASGI server to handle the request and send the
            response. It delegates to the stream_response method to perform the actual streaming process.

            Args:
                scope: The scope of the request, containing information about the client, server, and request itself.
                receive: A callable to receive the request, provided by the ASGI server.
                send: A callable to send the response, provided by the ASGI server.
            """
            # NOTE(review): scope and receive are intentionally unused, which
            # means client disconnects are not observed mid-stream — confirm
            # this is acceptable for long-running streams.
            await self.stream_response(send)

    @abstractmethod
    async def process_streaming_response(self, response):
        """
        Abstract method that must be implemented by the subclass.
        This method should provide logic to handle the streaming response, such as parsing and accumulating data.
        It is called as the response is being streamed from the network, and should be implemented to handle the specific
        streaming data format and requirements of the subclass.

        Args:
            response: The response object to be processed, typically containing chunks of data.
        """
        ...

    @abstractmethod
    def extract_response_json(self, response):
        """
        Abstract method that must be implemented by the subclass.
        This method should provide logic to extract JSON data from the response, including headers and content.
        It is called after the response has been processed and is responsible for retrieving structured data
        that can be used by the application.

        Args:
            response: The response object from which to extract JSON data.
        """
        ...

    def create_streaming_response(
        self, token_streamer: Callable[[Send], Awaitable[None]]
    ) -> "StreamingSynapse.BTStreamingResponse":
        """
        Creates a streaming response using the provided token streamer.
        This method can be used by the subclass to create a response object that can be sent back to the client.
        The token streamer should be implemented to generate the content of the response according to the specific
        requirements of the subclass.

        Args:
            token_streamer: A callable that takes a send function and returns an awaitable. It's responsible for generating the content of the response.

        Returns:
            BTStreamingResponse: The streaming response object, ready to be sent to the client.
        """
        bittensor.logging.trace("Creating streaming response.")

        # Wrap the streamer in the pydantic model to validate its signature.
        model_instance = BTStreamingResponseModel(token_streamer=token_streamer)

        return self.BTStreamingResponse(model_instance)
16 changes: 14 additions & 2 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ class Config:
validate_assignment = True

def deserialize(self) -> "Synapse":
    """Return this synapse unchanged (identity deserialization).

    Subclasses override this hook when wire-format fields must be turned
    back into richer objects; a subclass that needs no custom decoding
    simply inherits this default, which hands back the instance as-is.

    Returns:
        Synapse: This very instance, untouched.
    """
    return self

@pydantic.root_validator(pre=True)
Expand Down Expand Up @@ -552,7 +566,6 @@ def to_headers(self) -> dict:
headers[f"bt_header_dict_tensor_{field}"] = str(serialized_dict_tensor)

elif required and field in required:
bittensor.logging.trace(f"Serializing {field} with json...")
try:
serialized_value = json.dumps(value)
encoded_value = base64.b64encode(serialized_value.encode()).decode(
Expand Down Expand Up @@ -656,7 +669,6 @@ def parse_headers_to_inputs(cls, headers: dict) -> dict:
continue
# Handle 'input_obj' headers
elif "bt_header_input_obj" in key:
bittensor.logging.trace(f"Deserializing {key} with json...")
try:
new_key = key.split("bt_header_input_obj_")[1]
# Skip if the key already exists in the dictionary
Expand Down
Loading