-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add request-response services and topic (#1)
* Add first implementation of request response bridge * Improve license * Fix bridge class * Fix colcon tests * Add CI * Fix tests for ROS2 rolling
- Loading branch information
1 parent
d2629cc
commit 8b5a1b3
Showing
7 changed files
with
372 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
name: Build | ||
|
||
on: | ||
workflow_dispatch: | ||
push: | ||
branches: ["main"] | ||
pull_request: | ||
branches: ["main"] | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-24.04 | ||
permissions: | ||
contents: read | ||
pull-requests: read | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
distro: [jazzy, rolling] | ||
env: | ||
ROS_DISTRO: ${{ matrix.distro }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
path: src/hyveos | ||
- uses: actions/checkout@v4 | ||
with: | ||
repository: p2p-industries/hyveos_ros_msgs | ||
ssh-key: ${{ secrets.MSGS_REPO_DEPLOY_KEY }} | ||
path: src/hyveos_msgs | ||
- name: Build and test workspace | ||
uses: ichiro-its/ros2-ws-action@v1.0.1 | ||
with: | ||
distro: ${{ matrix.distro }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from abc import ABC, abstractmethod | ||
import asyncio | ||
from pathlib import Path | ||
from signal import SIGINT, SIGTERM | ||
import traceback | ||
|
||
from hyveos_sdk import Connection, OpenedConnection | ||
import rclpy | ||
from rclpy.node import Node | ||
|
||
|
||
def service_callback(f): | ||
async def inner_wrapper(self, request, response): | ||
try: | ||
return await f(self, request, response) | ||
except Exception as e: | ||
response.success = False | ||
response.error = str(e) | ||
return response | ||
|
||
# Allows calling __await__ repeatedly on awaitables that require waiting for a future before | ||
# doing so (e.g. asyncio). This will make asyncio functions compatible with rclpy | ||
# implementation of async. See https://github.com/ros2/rclpy/issues/962 for more info. | ||
async def wrapper(self, request, response): | ||
coro = inner_wrapper(self, request, response) | ||
try: | ||
while True: | ||
future = coro.send(None) | ||
assert asyncio.isfuture(future) or future is None, \ | ||
'Unexpected awaitable behavior. Only use rclpy or asyncio awaitables.' | ||
if future is None: | ||
# coro is rclpy-style awaitable; await is expected to be called repeatedly. | ||
await asyncio.sleep(0) | ||
continue | ||
while not future.done(): | ||
# coro is asyncio-style awaitable; stop calling await until future is done. | ||
await asyncio.sleep(0) # yields None | ||
future.result() | ||
except StopIteration as e: | ||
return e.value | ||
|
||
return wrapper | ||
|
||
|
||
def prepare_data(data: bytes | list[bytes]) -> bytes: | ||
if isinstance(data, bytes): | ||
return data | ||
elif isinstance(data, list): | ||
return b''.join(data) | ||
else: | ||
raise ValueError('Invalid data') | ||
|
||
|
||
class BridgeClient(ABC): | ||
|
||
@abstractmethod | ||
def __init__(self, node: 'Bridge'): | ||
pass | ||
|
||
@abstractmethod | ||
async def run(self): | ||
pass | ||
|
||
|
||
class Bridge(Node): | ||
connection: OpenedConnection | ||
bridge_clients: list[BridgeClient] | ||
|
||
def __init__(self, connection: OpenedConnection): | ||
super().__init__('hyveos_bridge') | ||
|
||
from .reqres import ReqResClient as _ # noqa: F401 | ||
|
||
self.connection = connection | ||
self.bridge_clients = [client(self) for client in BridgeClient.__subclasses__()] | ||
|
||
for client in self.bridge_clients: | ||
self.get_logger().info(f'Initializing {client.__class__.__name__}') | ||
|
||
async def run(self): | ||
coroutines = [client.run() for client in self.bridge_clients] | ||
await asyncio.gather(*coroutines) | ||
|
||
|
||
async def ros_loop(node: Node): | ||
while rclpy.ok(): | ||
rclpy.spin_once(node, timeout_sec=0) | ||
await asyncio.sleep(1e-4) | ||
|
||
|
||
async def async_main(args=None): | ||
def find_bridge_path(name: str) -> Path: | ||
candidates = ['/run', '/var/run', '/tmp'] | ||
|
||
for candidate in candidates: | ||
path = Path(candidate) / 'hyved' / 'bridge' / name | ||
if path.exists(): | ||
return path | ||
|
||
raise FileNotFoundError(f'Bridge {name} not found') | ||
|
||
socket_path = find_bridge_path('bridge.sock') | ||
shared_dir_path = find_bridge_path('files') | ||
|
||
async with Connection(socket_path=socket_path, shared_dir_path=shared_dir_path) as connection: | ||
try: | ||
rclpy.init(args=args) | ||
|
||
bridge = Bridge(connection) | ||
|
||
await asyncio.gather(ros_loop(bridge), bridge.run()) | ||
except asyncio.CancelledError: | ||
print('Exiting...') | ||
except Exception: | ||
traceback.print_exc() | ||
finally: | ||
if rclpy.ok(): | ||
bridge.destroy_node() | ||
rclpy.shutdown() | ||
|
||
|
||
def main(args=None): | ||
loop = asyncio.get_event_loop() | ||
main_task = asyncio.ensure_future(async_main(args=args)) | ||
for signal in [SIGINT, SIGTERM]: | ||
loop.add_signal_handler(signal, main_task.cancel) | ||
try: | ||
loop.run_until_complete(main_task) | ||
finally: | ||
loop.close() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import asyncio | ||
|
||
from hyveos_msgs.msg import ReceivedRequest | ||
from hyveos_msgs.srv import RequestSubscription, Respond, SendRequest | ||
from hyveos_sdk import ManagedStream, RequestResponseService | ||
from hyveos_sdk.protocol.script_pb2 import RecvRequest | ||
from rclpy.impl.rcutils_logger import RcutilsLogger | ||
from rclpy.publisher import Publisher | ||
|
||
from .bridge import Bridge, BridgeClient, prepare_data, service_callback | ||
|
||
|
||
class Subscription: | ||
event: asyncio.Event | ||
task: asyncio.Task | ||
logger: RcutilsLogger | ||
|
||
def __init__( | ||
self, | ||
stream: ManagedStream[RecvRequest], | ||
publisher: Publisher, | ||
logger: RcutilsLogger | ||
): | ||
self.event = asyncio.Event() | ||
self.task = asyncio.create_task(self.run(stream, publisher)) | ||
self.logger = logger | ||
|
||
async def run(self, stream: ManagedStream[RecvRequest], publisher: Publisher): | ||
async with stream: | ||
iterator = stream.__aiter__() | ||
|
||
while True: | ||
data_task = asyncio.create_task(iterator.__anext__()) | ||
event_task = asyncio.create_task(self.event.wait()) | ||
|
||
done, _ = await asyncio.wait( | ||
[data_task, event_task], | ||
return_when=asyncio.FIRST_COMPLETED | ||
) | ||
|
||
if data_task in done: | ||
request = data_task.result() | ||
|
||
self.logger.info(f'Received request {request.seq} from {request.peer.peer_id}') | ||
|
||
request_msg = ReceivedRequest() | ||
request_msg.peer = request.peer.peer_id | ||
if request.msg.topic.topic is None: | ||
request_msg.topic = '' | ||
request_msg.no_topic = True | ||
else: | ||
request_msg.topic = request.msg.topic.topic.topic | ||
request_msg.no_topic = False | ||
request_msg.data = request.msg.data.data | ||
request_msg.seq = request.seq | ||
|
||
publisher.publish(request_msg) | ||
|
||
if event_task in done: | ||
break | ||
|
||
async def cancel(self): | ||
self.event.set() | ||
await self.task | ||
|
||
|
||
class ReqResClient(BridgeClient): | ||
logger: RcutilsLogger | ||
req_res: RequestResponseService | ||
subscriptions: dict[str | None, Subscription] | ||
subscriptions_lock: asyncio.Lock | ||
|
||
def __init__(self, node: Bridge): | ||
def namespaced(name: str) -> str: | ||
return f'{node.get_name()}/req_res/{name}' | ||
|
||
self.received_requests_publisher = node.create_publisher( | ||
ReceivedRequest, | ||
namespaced('received_requests'), | ||
10 | ||
) | ||
self.send_request_service = node.create_service( | ||
SendRequest, | ||
namespaced('send_request'), | ||
self._send_request_callback | ||
) | ||
self.subscribe_service = node.create_service( | ||
RequestSubscription, | ||
namespaced('subscribe'), | ||
self._subscribe_callback | ||
) | ||
self.unsubscribe_service = node.create_service( | ||
RequestSubscription, | ||
namespaced('unsubscribe'), | ||
self._unsubscribe_callback | ||
) | ||
self.respond_service = node.create_service( | ||
Respond, | ||
namespaced('respond'), | ||
self._respond_callback | ||
) | ||
|
||
self.logger = node.get_logger() | ||
self.req_res = node.connection.get_request_response_service() | ||
self.subscriptions = {} | ||
self.subscriptions_lock = asyncio.Lock() | ||
|
||
@service_callback | ||
async def _send_request_callback( | ||
self, | ||
request: SendRequest.Request, | ||
response: SendRequest.Response | ||
): | ||
if request.no_topic: | ||
self.logger.info(f'Sending request without topic to {request.peer}') | ||
topic = None | ||
else: | ||
self.logger.info(f'Sending request with topic {request.topic} to {request.peer}') | ||
topic = request.topic | ||
|
||
data = prepare_data(request.data) | ||
|
||
res = await self.req_res.send_request(request.peer, data, topic=topic) | ||
|
||
if res.WhichOneof('response') == 'data': | ||
response.success = True | ||
response.response = res.data.data | ||
return response | ||
elif res.WhichOneof('response') == 'error': | ||
raise ValueError(res.error) | ||
else: | ||
raise ValueError('Invalid response') | ||
|
||
@service_callback | ||
async def _subscribe_callback( | ||
self, | ||
request: RequestSubscription.Request, | ||
response: RequestSubscription.Response | ||
): | ||
if request.no_topic: | ||
self.logger.info('Subscribing to messages without topic') | ||
topic = None | ||
else: | ||
self.logger.info(f'Subscribing to messages with topic {request.topic}') | ||
topic = request.topic | ||
|
||
async with self.subscriptions_lock: | ||
if topic not in self.subscriptions: | ||
stream = self.req_res.receive(query=topic) | ||
self.subscriptions[topic] = Subscription( | ||
stream, | ||
self.received_requests_publisher, | ||
self.logger | ||
) | ||
else: | ||
raise ValueError('Already subscribed to topic') | ||
|
||
response.success = True | ||
return response | ||
|
||
@service_callback | ||
async def _unsubscribe_callback( | ||
self, | ||
request: RequestSubscription.Request, | ||
response: RequestSubscription.Response | ||
): | ||
if request.no_topic: | ||
self.logger.info('Unsubscribing from messages without topic') | ||
topic = None | ||
else: | ||
self.logger.info(f'Unsubscribing from messages with topic {request.topic}') | ||
topic = request.topic | ||
|
||
async with self.subscriptions_lock: | ||
if topic in self.subscriptions: | ||
await self.subscriptions.pop(topic).cancel() | ||
else: | ||
raise ValueError('Not subscribed to topic') | ||
|
||
response.success = True | ||
return response | ||
|
||
@service_callback | ||
async def _respond_callback(self, request: Respond.Request, response: Respond.Response): | ||
self.logger.info(f'Responding to request {request.seq}') | ||
|
||
if request.success: | ||
data = prepare_data(request.response) | ||
await self.req_res.respond(request.seq, data) | ||
else: | ||
await self.req_res.respond(request.seq, b'', error=request.error) | ||
|
||
response.success = True | ||
return response | ||
|
||
async def run(self): | ||
pass |
Oops, something went wrong.