Skip to content

Commit

Permalink
cleanup dendrite multi_forward
Browse files Browse the repository at this point in the history
  • Loading branch information
ifrit98 committed Jun 8, 2023
1 parent 2042122 commit 8fb3aab
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 67 deletions.
37 changes: 7 additions & 30 deletions bittensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,6 @@ def turn_console_off():

from bittensor._proto.bittensor_pb2 import ForwardTextPromptingRequest
from bittensor._proto.bittensor_pb2 import ForwardTextPromptingResponse
from bittensor._proto.bittensor_pb2 import MultiForwardTextPromptingRequest
from bittensor._proto.bittensor_pb2 import MultiForwardTextPromptingResponse
from bittensor._proto.bittensor_pb2 import BackwardTextPromptingRequest
from bittensor._proto.bittensor_pb2 import BackwardTextPromptingResponse

Expand All @@ -208,12 +206,6 @@ def turn_console_off():
from bittensor._dendrite.text_prompting.dendrite import TextPromptingDendrite as text_prompting
from bittensor._dendrite.text_prompting.dendrite_pool import TextPromptingDendritePool as text_prompting_pool

# ---- Base Miners -----
from bittensor._neuron.base_miner_neuron import BaseMinerNeuron
from bittensor._neuron.base_validator import BaseValidator
from bittensor._neuron.base_prompting_miner import BasePromptingMiner
from bittensor._neuron.base_huggingface_miner import HuggingFaceMiner

# ---- Errors and Exceptions -----
from bittensor._keyfile.keyfile_impl import KeyFileError as KeyFileError

Expand Down Expand Up @@ -322,19 +314,11 @@ def forward(
return_all: bool = False,
) -> Union[str, List[str]]:
roles, messages = self.format_content( content )
if not return_all:
return self._dendrite.forward(
roles = roles,
messages = messages,
timeout = timeout
).completion
else:
return self._dendrite.multi_forward(
roles = roles,
messages = messages,
timeout = timeout
).multi_completions

return self._dendrite.forward(
roles = roles,
messages = messages,
timeout = timeout
).completion

async def async_forward(
self,
Expand All @@ -343,18 +327,11 @@ async def async_forward(
return_all: bool = False,
) -> Union[str, List[str]]:
roles, messages = self.format_content( content )
if not return_all:
return await self._dendrite.async_forward(
roles = roles,
messages = messages,
timeout = timeout
).completion
else:
return self._dendrite.async_multi_forward(
return await self._dendrite.async_forward(
roles = roles,
messages = messages,
timeout = timeout
).multi_completions
).completion

class BittensorLLM(LLM):
"""Wrapper around Bittensor Prompting Subnetwork.
Expand Down
38 changes: 1 addition & 37 deletions bittensor/_dendrite/text_prompting/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import grpc
import json
import torch
import asyncio
import bittensor
from typing import Callable, List, Dict, Union
from typing import Callable, List, Union

class DendriteForwardCall( bittensor.DendriteCall ):

Expand Down Expand Up @@ -198,40 +196,6 @@ async def async_forward(
if return_call: return forward_call
else: return forward_call.completion

def multi_forward(
self,
roles: List[ str ] ,
messages: List[ str ],
timeout: float = bittensor.__blocktime__,
return_call:bool = True,
) -> Union[ str, DendriteForwardCall ]:
forward_call = MultiDendriteForwardCall(
dendrite = self,
messages = messages,
roles = roles,
timeout = timeout,
)
response_call = self.loop.run_until_complete( self.apply( dendrite_call = forward_call ) )
if return_call: return response_call
else: return response_call.multi_completions

async def async_multi_forward(
self,
roles: List[ str ],
messages: List[ str ],
timeout: float = bittensor.__blocktime__,
return_call: bool = True,
) -> Union[ str, DendriteForwardCall ]:
forward_call = MultiDendriteForwardCall(
dendrite = self,
messages = messages,
roles = roles,
timeout = timeout,
)
forward_call = await self.apply( dendrite_call = forward_call )
if return_call: return forward_call
else: return forward_call.multi_completions

def backward(
self,
roles: List[ str ],
Expand Down

0 comments on commit 8fb3aab

Please sign in to comment.