Skip to content

Commit

Permalink
Ensure request and response sequences remain aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
bryananderson committed Jan 29, 2025
1 parent 0bbf935 commit fff1b25
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
16 changes: 14 additions & 2 deletions pyht/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ async def lease_factory() -> Lease:
self._api_key = api_key
self._inference_coordinates: Optional[Dict[str, Any]] = None
self._ws: Optional[ClientConnection] = None
self._ws_requests_sent = 0
self._ws_responses_received = 0

if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
self._max_attempts = 3
Expand Down Expand Up @@ -454,21 +456,31 @@ async def _tts_ws(
ws_address = self._inference_coordinates[voice_engine]["websocket_url"]
if self._ws is None:
self._ws = await connect(ws_address)
self._ws_requests_sent = 0
self._ws_responses_received = 0
try:
await self._ws.send(json.dumps(json_data))
self._ws_requests_sent += 1
except ConnectionClosed as e:
logging.debug(f"Reconnecting websocket which closed unexpectedly: {e}")
self._ws = await connect(ws_address)
self._ws_requests_sent = 0
self._ws_responses_received = 0
await self._ws.send(json.dumps(json_data))
self._ws_requests_sent += 1
chunk_idx = -1
request_id = -1
started = False
async for chunk in self._ws:
if isinstance(chunk, str):
msg = json.loads(chunk)
if msg["type"] == "start":
started = True
request_id = msg["request_id"]
self._ws_responses_received += 1
if self._ws_responses_received == self._ws_requests_sent:
started = True
request_id = msg["request_id"]
elif self._ws_responses_received > self._ws_requests_sent:
raise Exception("Received more responses than requests")
elif msg["type"] == "end" and msg["request_id"] == request_id:
break
else:
Expand Down
16 changes: 14 additions & 2 deletions pyht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def lease_factory() -> Lease:
self._api_key = api_key
self._inference_coordinates: Optional[Dict[str, Any]] = None
self._ws: Optional[ClientConnection] = None
self._ws_requests_sent = 0
self._ws_responses_received = 0

if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
self._max_attempts = 3
Expand Down Expand Up @@ -723,12 +725,18 @@ def _tts_ws(
ws_address = self._inference_coordinates[voice_engine]["websocket_url"]
if self._ws is None:
self._ws = connect(ws_address)
self._ws_requests_sent = 0
self._ws_responses_received = 0
try:
self._ws.send(json.dumps(json_data))
self._ws_requests_sent += 1
except ConnectionClosed as e:
logging.debug(f"Reconnecting websocket which closed unexpectedly: {e}")
self._ws = connect(ws_address)
self._ws_requests_sent = 0
self._ws_responses_received = 0
self._ws.send(json.dumps(json_data))
self._ws_requests_sent += 1
chunk_idx = -1
request_id = -1
started = False
Expand All @@ -737,8 +745,12 @@ def _tts_ws(
if isinstance(chunk, str):
msg = json.loads(chunk)
if msg["type"] == "start":
started = True
request_id = msg["request_id"]
self._ws_responses_received += 1
if self._ws_responses_received == self._ws_requests_sent:
started = True
request_id = msg["request_id"]
elif self._ws_responses_received > self._ws_requests_sent:
raise Exception("Received more responses than requests")
elif msg["type"] == "end" and msg["request_id"] == request_id:
break
else:
Expand Down

0 comments on commit fff1b25

Please sign in to comment.