From aedc92ebafe939255188ba2a88a4b3e9de5ddc51 Mon Sep 17 00:00:00 2001
From: Haochen Pan
Date: Fri, 13 Sep 2024 12:29:30 -0400
Subject: [PATCH] OctopusQueues impl, update project dep and run_parallel_workflow.py flake

---
 mofa/octopus.py          | 224 ++++++++++++++++++++++++++++++++++++++++++++++
 pyproject.toml           |   5 +-
 run_parallel_workflow.py |   6 +-
 3 files changed, 230 insertions(+), 5 deletions(-)
 create mode 100644 mofa/octopus.py

diff --git a/mofa/octopus.py b/mofa/octopus.py
new file mode 100644
index 00000000..bcb7f08b
--- /dev/null
+++ b/mofa/octopus.py
@@ -0,0 +1,224 @@
+"""Queues that use Octopus"""
+
+from typing import Collection, Optional, Union, Dict, Tuple
+import logging
+
+import os
+import json
+
+import pickle
+
+from colmena.exceptions import TimeoutException, KillSignalException
+from colmena.models import SerializationMethod
+
+from colmena.queue.base import ColmenaQueues
+from diaspora_event_sdk import KafkaProducer
+from diaspora_event_sdk import KafkaConsumer
+from kafka.errors import KafkaError
+
+from time import time
+
+
+logger = logging.getLogger(__name__)
+
+
+def value_serializer(v):
+    return json.dumps(v).encode("utf-8")
+
+
+def value_deserializer(x):
+    return json.loads(x.decode("utf-8"))
+
+
+class OctopusQueues(ColmenaQueues):
+    def __init__(
+        self,
+        topics: Collection[str],
+        prefix: str = "mofa_test1",
+        auto_offset_reset: str = "earliest",
+        discard_events_before: int = int(time() * 1000),
+        serialization_method: Union[
+            str, SerializationMethod
+        ] = SerializationMethod.PICKLE,
+        keep_inputs: bool = True,
+        proxystore_name: Optional[Union[str, Dict[str, str]]] = None,
+        proxystore_threshold: Optional[Union[int, Dict[str, int]]] = None,
+    ):
+        assert os.environ["OCTOPUS_AWS_ACCESS_KEY_ID"]
+        assert os.environ["OCTOPUS_AWS_SECRET_ACCESS_KEY"]
+        assert os.environ["OCTOPUS_BOOTSTRAP_SERVERS"]
+        super().__init__(
+            topics,
+            serialization_method,
+            keep_inputs,
+            proxystore_name,
+            proxystore_threshold,
+        )
+        # self.topics is handled in super
+        self.prefix = prefix
+        self.auto_offset_reset = auto_offset_reset
+        self.discard_events_before = discard_events_before
+
+        self.request_producer = None
+        self.request_consumer = None
+        self.result_consumers = {}
+
+    def connect_request_producer(self):
+        """Connect the request producer."""
+        if not isinstance(self.request_producer, KafkaProducer):
+            self.request_producer = KafkaProducer(
+                value_serializer=value_serializer,
+            )
+
+    def connect_request_consumer(self):
+        """Connect the request consumer."""
+        if not isinstance(self.request_consumer, KafkaConsumer):
+            request_topic = f"{self.prefix}_requests"
+            self.request_consumer = KafkaConsumer(
+                request_topic,
+                auto_offset_reset=self.auto_offset_reset,
+                value_deserializer=value_deserializer,
+                max_poll_records=1,
+            )
+
+    def connect_result_consumer(self, topic):
+        """Connect a result consumer for a specific topic."""
+        if (topic not in self.result_consumers) or not isinstance(
+            self.result_consumers[topic], KafkaConsumer
+        ):
+            result_topic = f"{self.prefix}_{topic}_result"
+            self.result_consumers[topic] = KafkaConsumer(
+                result_topic,
+                auto_offset_reset=self.auto_offset_reset,
+                value_deserializer=value_deserializer,
+                max_poll_records=1,
+            )
+
+    def disconnect_request_producer(self):
+        """Disconnect the request producer."""
+        if self.request_producer:
+            self.request_producer.close()
+
+    def disconnect_request_consumer(self):
+        """Disconnect the request consumer."""
+        if self.request_consumer:
+            self.request_consumer.close()
+            self.request_consumer = None
+
+    def disconnect_result_consumer(self, topic):
+        """Disconnect the result consumer for a specific topic."""
+        consumer = self.result_consumers.pop(topic, None)
+        if consumer:
+            consumer.close()
+
+    def __getstate__(self):
+        state = super().__getstate__()
+
+        if self.request_producer:
+            state["request_producer"] = "connected"
+
+        if self.request_consumer:
+            state["request_consumer"] = "connected"
+
+        for topic in list(self.result_consumers.keys()):
+            state["result_consumers"][topic] = "connected"
+
+        return state
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+
+        if self.request_producer:
+            self.connect_request_producer()
+
+        if self.request_consumer:
+            self.connect_request_consumer()
+
+        for topic in list(self.result_consumers.keys()):
+            self.connect_result_consumer(topic)
+
+    def _publish_event(self, message, octopus_topic):
+        try:
+            future = self.request_producer.send(octopus_topic, message)
+            res = future.get(timeout=5)
+            self.request_producer.flush()
+            return res
+        except KafkaError as e:
+            print(f"Error producing message: {e}")
+
+    def _send_request(self, message: str, topic: str):
+        self.connect_request_producer()
+        queue = f"{self.prefix}_requests"
+        event = {"message": message, "topic": topic}
+        self._publish_event(event, queue)
+
+    def _get_message(
+        self,
+        consumer,
+        timeout: float = None,
+    ):
+        if timeout is None:
+            timeout = 0
+        timeout *= 1000  # to ms
+        assert consumer, "should be initialized"
+
+        try:
+            while True:  # blocks indefinitely
+                events = consumer.poll(timeout)
+                for tp, es in events.items():
+                    ts, event = es[0].timestamp, es[0].value
+                    if ts < self.discard_events_before:
+                        continue
+
+                    return event
+
+        except KafkaError as e:
+            print(f"Error consuming message: {e}")
+            raise TimeoutException()
+
+    def _get_request(self, timeout: float = None) -> Tuple[str, str]:
+        self.connect_request_consumer()
+
+        event = self._get_message(self.request_consumer, timeout)
+        if event["message"].endswith("null"):
+            raise KillSignalException()
+
+        topic, message = event["topic"], event["message"]
+        return topic, message
+
+    def _send_result(self, message: str, topic: str):
+        self.connect_request_producer()
+        queue = f"{self.prefix}_{topic}_result"
+        self._publish_event(message, queue)
+
+    def _get_result(self, topic: str, timeout: int = None) -> str:
+        self.connect_result_consumer(topic)
+        consumer = self.result_consumers.get(topic)
+        if not consumer:
+            raise ConnectionError(
+                f"No consumer connected for topic '{topic}'. Did you call 'connect_result_consumer('{topic}')'?"
+            )
+
+        event = self._get_message(consumer, timeout)
+        return event
+
+
+if __name__ == "__main__":
+    queues = OctopusQueues(
+        topics=["generation", "lammps", "cp2k", "training", "assembly"]
+    )
+    print(queues)
+    print(queues.topics)
+
+    queues.connect_request_producer()
+    queues.connect_request_consumer()
+    for topic in queues.topics:
+        queues.connect_result_consumer(topic)
+
+    queues_dumped = pickle.dumps(queues)
+    print(queues_dumped)
+
+    queues_loaded = pickle.loads(queues_dumped)
+    print(queues_loaded.request_producer)
+    print(queues_loaded.request_consumer)
+    print(queues_loaded.result_consumers)
diff --git a/pyproject.toml b/pyproject.toml
index e5d78ef3..41000d31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,10 @@ dependencies = [
     "gpustat",
 
     # Data management
-    "pymongo>=4"
+    "pymongo>=4",
+
+    # Mini-app
+    "diaspora-event-sdk[kafka-python]"
 ]
 
 [tool.setuptools.packages.find]
diff --git a/run_parallel_workflow.py b/run_parallel_workflow.py
index 19f13e1d..98c53927 100644
--- a/run_parallel_workflow.py
+++ b/run_parallel_workflow.py
@@ -32,7 +32,6 @@
 from colmena.models import Result
 from colmena.task_server.parsl import ParslTaskServer
 from colmena.queue import ColmenaQueues
-from colmena.queue.redis import RedisQueues
 from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent
 
 from mofa.assembly.assemble import assemble_many
@@ -47,6 +46,7 @@
 from mofa.hpc.colmena import DiffLinkerInference
 from mofa import db as mofadb
 from mofa.hpc.config import configs as hpc_configs, HPCConfig
+from mofa.octopus import OctopusQueues
 
 RDLogger.DisableLog('rdApp.*')
 ob.obErrorLog.SetOutputLevel(0)
@@ -627,9 +627,7 @@ def store_cp2k(self, result: Result):
     store = Store(name='redis', connector=RedisConnector(hostname=args.redis_host, port=6379), metrics=True)
     register_store(store)
 
-    # Configure to a use Redis queue, which allows streaming results form other nodes
-    queues = RedisQueues(
-        hostname=args.redis_host,
+    queues = OctopusQueues(
         topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'],
         proxystore_name='redis',
         proxystore_threshold=args.proxy_threshold
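
A minimal usage sketch of the new queues, assuming valid Octopus credentials (the values below are placeholders) and the same topic list that run_parallel_workflow.py passes in:

    import os

    # OctopusQueues.__init__ asserts these variables are set before it
    # builds any Kafka clients; the values here are placeholders
    os.environ["OCTOPUS_AWS_ACCESS_KEY_ID"] = "<access-key-id>"
    os.environ["OCTOPUS_AWS_SECRET_ACCESS_KEY"] = "<secret-access-key>"
    os.environ["OCTOPUS_BOOTSTRAP_SERVERS"] = "<broker-host:9092>"

    from mofa.octopus import OctopusQueues

    # The prefix namespaces the underlying Kafka topics,
    # e.g. "<prefix>_requests" and "<prefix>_lammps_result"
    queues = OctopusQueues(
        topics=["generation", "lammps", "cp2k", "training", "assembly"],
        prefix="mofa_test1",
    )
    queues.connect_request_producer()
    for topic in queues.topics:
        queues.connect_result_consumer(topic)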