|
| 1 | +import asyncio |
| 2 | +import contextlib |
1 | 3 | import re
|
| 4 | +import threading |
2 | 5 | import time
|
3 | 6 | from dataclasses import dataclass
|
4 | 7 | from typing import Any, Optional, Tuple
|
5 | 8 | from unittest import IsolatedAsyncioTestCase
|
6 | 9 | from unittest.mock import AsyncMock, MagicMock, patch
|
7 | 10 |
|
| 11 | +import aiohttp |
8 | 12 | import fastapi
|
9 | 13 | import netaddr
|
10 | 14 | import pydantic
|
11 | 15 | import pytest
|
| 16 | +import uvicorn |
12 | 17 | from fastapi.testclient import TestClient
|
13 | 18 | from starlette.requests import Request
|
14 | 19 |
|
15 |
| -from bittensor.core.axon import AxonMiddleware, Axon |
| 20 | +from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer |
16 | 21 | from bittensor.core.errors import RunException
|
17 | 22 | from bittensor.core.settings import version_as_int
|
18 | 23 | from bittensor.core.stream import StreamingSynapse
|
|
26 | 31 | )
|
27 | 32 |
|
28 | 33 |
|
29 |
| -def test_attach_initial(): |
| 34 | +def test_attach_initial(mock_get_external_ip): |
30 | 35 | # Create a mock AxonServer instance
|
31 | 36 | server = Axon()
|
32 | 37 |
|
@@ -71,7 +76,7 @@ def wrong_verify_fn(synapse: TestSynapse) -> bool:
|
71 | 76 | server.attach(forward_fn, blacklist_fn, priority_fn, wrong_verify_fn)
|
72 | 77 |
|
73 | 78 |
|
74 |
| -def test_attach(): |
| 79 | +def test_attach(mock_get_external_ip): |
75 | 80 | # Create a mock AxonServer instance
|
76 | 81 | server = Axon()
|
77 | 82 |
|
@@ -144,7 +149,7 @@ def mock_request():
|
144 | 149 |
|
145 | 150 |
|
146 | 151 | @pytest.fixture
|
147 |
| -def axon_instance(): |
| 152 | +def axon_instance(mock_get_external_ip): |
148 | 153 | axon = Axon()
|
149 | 154 | axon.required_hash_fields = {"test_endpoint": ["field1", "field2"]}
|
150 | 155 | axon.forward_class_types = {
|
@@ -329,7 +334,7 @@ async def test_verify_body_integrity_error_cases(
|
329 | 334 | (MockInfo(), "MockInfoString", "edge_case_empty_string"),
|
330 | 335 | ],
|
331 | 336 | )
|
332 |
| -def test_to_string(info_return, expected_output, test_id): |
| 337 | +def test_to_string(info_return, expected_output, test_id, mock_get_external_ip): |
333 | 338 | # Arrange
|
334 | 339 | axon = Axon()
|
335 | 340 | with patch.object(axon, "info", return_value=info_return):
|
@@ -358,7 +363,9 @@ def test_to_string(info_return, expected_output, test_id):
|
358 | 363 | ),
|
359 | 364 | ],
|
360 | 365 | )
|
361 |
| -def test_valid_ipv4_and_ipv6_address(ip, port, expected_ip_type, test_id): |
| 366 | +def test_valid_ipv4_and_ipv6_address( |
| 367 | + ip, port, expected_ip_type, test_id, mock_get_external_ip |
| 368 | +): |
362 | 369 | # Arrange
|
363 | 370 | hotkey = MockHotkey("5EemgxS7cmYbD34esCFoBgUZZC8JdnGtQvV5Qw3QFUCRRtGP")
|
364 | 371 | coldkey = MockHotkey("5EemgxS7cmYbD34esCFoBgUZZC8JdnGtQvV5Qw3QFUCRRtGP")
|
@@ -431,7 +438,14 @@ def test_invalid_ip_address(ip, port, expected_exception):
|
431 | 438 | ],
|
432 | 439 | )
|
433 | 440 | def test_axon_str_representation(
|
434 |
| - ip, port, ss58_address, started, forward_fns, expected_str, test_id |
| 441 | + ip, |
| 442 | + port, |
| 443 | + ss58_address, |
| 444 | + started, |
| 445 | + forward_fns, |
| 446 | + expected_str, |
| 447 | + test_id, |
| 448 | + mock_get_external_ip, |
435 | 449 | ):
|
436 | 450 | # Arrange
|
437 | 451 | hotkey = MockHotkey(ss58_address)
|
@@ -765,3 +779,50 @@ async def forward_fn(synapse: streaming_synapse_cls):
|
765 | 779 | "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
|
766 | 780 | },
|
767 | 781 | )
|
| 782 | + |
| 783 | + |
| 784 | +@pytest.mark.asyncio |
| 785 | +async def test_threaded_fastapi(): |
| 786 | + server_started = threading.Event() |
| 787 | + server_stopped = threading.Event() |
| 788 | + |
| 789 | + @contextlib.asynccontextmanager |
| 790 | + async def lifespan(app): |
| 791 | + server_started.set() |
| 792 | + yield |
| 793 | + server_stopped.set() |
| 794 | + |
| 795 | + app = fastapi.FastAPI( |
| 796 | + lifespan=lifespan, |
| 797 | + ) |
| 798 | + app.get("/")(lambda: "Hello World") |
| 799 | + |
| 800 | + server = FastAPIThreadedServer( |
| 801 | + uvicorn.Config(app, loop="none"), |
| 802 | + ) |
| 803 | + server.start() |
| 804 | + |
| 805 | + server_started.wait(3.0) |
| 806 | + |
| 807 | + async def wait_for_server(): |
| 808 | + while not (server.started or server_stopped.is_set()): |
| 809 | + await asyncio.sleep(1.0) |
| 810 | + |
| 811 | + await asyncio.wait_for(wait_for_server(), 7.0) |
| 812 | + |
| 813 | + assert server.is_running is True |
| 814 | + |
| 815 | + async with aiohttp.ClientSession( |
| 816 | + base_url="http://127.0.0.1:8000", |
| 817 | + ) as session: |
| 818 | + async with session.get("/") as response: |
| 819 | + assert await response.text() == '"Hello World"' |
| 820 | + |
| 821 | + server.stop() |
| 822 | + |
| 823 | + assert server.should_exit is True |
| 824 | + |
| 825 | + server_stopped.wait() |
| 826 | + |
| 827 | + with pytest.raises(aiohttp.ClientConnectorError): |
| 828 | + await session.get("/") |
0 commit comments