forked from permitio/fastapi_websocket_rpc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rpc_methods.py
99 lines (77 loc) · 2.7 KB
/
rpc_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import asyncio
import copy
import os
import sys
import typing

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, Field

from .connection_manager import ConnectionManager
from .schemas import RpcRequest, RpcResponse
from .utils import gen_uid
# Reply returned by the built-in keep-alive method `_ping_`.
PING_RESPONSE = "pong"

# Internal (underscore-wrapped) methods that a remote peer is allowed to call.
EXPOSED_BUILT_IN_METHODS = [
    "_ping_",
    "_get_channel_id_",
]
class NoResponse:
    """
    Marker used as a NULL default value, indicating that no response was
    received.

    NOTE(review): from this file alone it is not visible whether callers use
    the class object itself or an instance of it as the marker - confirm
    before relying on either form.
    """
class RpcMethodsBase:
    """
    Base interface that RPC channels expect every group of methods to implement.

    Provides:
    - ``_copy_``: produce a copy of the methods object
    - ``_set_channel_`` / ``channel``: attach and read the owning channel
    - ``_ping_``: built-in keep-alive call
    - ``_get_channel_id_``: built-in call reporting the attached channel's id
    """

    def __init__(self):
        # No channel is attached until _set_channel_ is called.
        self._channel = None

    def _set_channel_(self, channel):
        """
        Attach the channel these methods are nested under, giving the methods
        access to the channel's functions through ``self.channel``.
        """
        self._channel = channel

    @property
    def channel(self):
        """The channel attached via ``_set_channel_`` (``None`` before that)."""
        return self._channel

    def _copy_(self):
        """
        Return a shallow copy (``copy.copy``) of this object.

        Subclasses whose state must not be shared between copies may need to
        override this.
        """
        duplicate = copy.copy(self)
        return duplicate

    async def _ping_(self) -> str:
        """Built-in keep-alive call; replies with ``PING_RESPONSE``."""
        reply = PING_RESPONSE
        return reply

    async def _get_channel_id_(self) -> str:
        """
        Built-in call returning the ``id`` of the attached channel, so the remote
        side can identify this end.
        """
        attached = self._channel
        return attached.id
class ProcessDetails(BaseModel):
    """
    Details of the running process: pid, command line and working directory.

    Every field is computed when an instance is created (``default_factory``)
    rather than once at import time. Values therefore stay correct after a
    fork (new pid) or an ``os.chdir`` (new working directory).
    """

    # Id of the process that creates the instance.
    pid: int = Field(default_factory=os.getpid)
    # Command-line arguments; copied so the instance does not alias sys.argv.
    cmd: typing.List[str] = Field(default_factory=lambda: list(sys.argv))
    # Current working directory at instance-creation time.
    workingdir: str = Field(default_factory=os.getcwd)
class RpcUtilityMethods(RpcMethodsBase):
    """
    A simple set of RPC functions useful for management and testing.
    """

    # Strong references to the fire-and-forget tasks started by call_me_back.
    # The asyncio event loop keeps only weak references to tasks, so a task
    # nobody references can be garbage-collected before it finishes.
    # Intentionally shared by all instances (and by copies made via _copy_).
    _pending_callbacks: typing.Set[asyncio.Task] = set()

    def __init__(self):
        """Create the methods object; a channel is attached later via _set_channel_."""
        super().__init__()

    async def get_process_details(self) -> ProcessDetails:
        """Return a ProcessDetails instance describing this process."""
        return ProcessDetails()

    async def call_me_back(self, method_name="", args=None) -> str:
        """
        Invoke ``method_name`` through the attached channel without waiting for it.

        Args:
            method_name: name of the method to call through the channel.
            args: arguments forwarded to the channel's ``async_call``; a fresh
                empty dict is used when omitted.

        Returns:
            The generated call id, which can later be passed to get_response to
            fetch the result once it has been received. Returns None when no
            channel is attached.
        """
        if self.channel is None:
            return None
        if args is None:
            # Fresh dict per call: a shared `{}` default could carry mutations
            # from one call into the next.
            args = {}
        # generate a uid we can use to track this request
        call_id = gen_uid()
        # Schedule the call without awaiting it, to avoid blocking the event loop.
        task = asyncio.create_task(
            self.channel.async_call(method_name, args=args, call_id=call_id)
        )
        # Keep the task alive until it completes, then drop the reference.
        self._pending_callbacks.add(task)
        task.add_done_callback(self._pending_callbacks.discard)
        # return the id - which can be used to check the response once it's received
        return call_id

    async def get_response(self, call_id="") -> typing.Any:
        """
        Return the channel's saved response for ``call_id`` and clear it.

        Returns None when no channel is attached.
        NOTE(review): what ``get_saved_response`` returns when the response has
        not arrived yet is defined by the channel, not here - confirm there.
        """
        if self.channel is None:
            return None
        res = self.channel.get_saved_response(call_id)
        self.channel.clear_saved_call(call_id)
        return res

    async def echo(self, text: str) -> str:
        """Return ``text`` unchanged."""
        return text