diff --git a/tests/unit_tests/test_axon.py b/tests/unit_tests/test_axon.py index 512b4635ae..f521eabc09 100644 --- a/tests/unit_tests/test_axon.py +++ b/tests/unit_tests/test_axon.py @@ -1,18 +1,23 @@ +import asyncio +import contextlib import re +import threading import time from dataclasses import dataclass from typing import Any, Optional, Tuple from unittest import IsolatedAsyncioTestCase from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import fastapi import netaddr import pydantic import pytest +import uvicorn from fastapi.testclient import TestClient from starlette.requests import Request -from bittensor.core.axon import AxonMiddleware, Axon +from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer from bittensor.core.errors import RunException from bittensor.core.settings import version_as_int from bittensor.core.stream import StreamingSynapse @@ -765,3 +770,50 @@ async def forward_fn(synapse: streaming_synapse_cls): "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a", }, ) + + +@pytest.mark.asyncio +async def test_threaded_fastapi(): + server_started = threading.Event() + server_stopped = threading.Event() + + @contextlib.asynccontextmanager + async def lifespan(app): + server_started.set() + yield + server_stopped.set() + + app = fastapi.FastAPI( + lifespan=lifespan, + ) + app.get("/")(lambda: "Hello World") + + server = FastAPIThreadedServer( + uvicorn.Config(app, loop="none"), + ) + server.start() + + server_started.wait(3.0) + + async def wait_for_server(): + while not (server.started or server_stopped.is_set()): + await asyncio.sleep(1.0) + + await asyncio.wait_for(wait_for_server(), 7.0) + + assert server.is_running is True + + async with aiohttp.ClientSession( + base_url="http://127.0.0.1:8000", + ) as session: + async with session.get("/") as response: + assert await response.text() == '"Hello World"' + + server.stop() + + assert server.should_exit is True + + server_stopped.wait() + + with pytest.raises(aiohttp.ClientConnectorError): + await session.get("/")