Commit

OctopusQueues impl, update project dep and run_parallel_workflow.py
flake
haochenpan committed Sep 19, 2024
1 parent 5a95d66 commit aedc92e
Showing 3 changed files with 230 additions and 5 deletions.
224 changes: 224 additions & 0 deletions mofa/octopus.py
@@ -0,0 +1,224 @@
"""Queues that use Octopus"""

import json
import logging
import os
import pickle
from time import time
from typing import Collection, Dict, Optional, Tuple, Union

from colmena.exceptions import TimeoutException, KillSignalException
from colmena.models import SerializationMethod
from colmena.queue.base import ColmenaQueues
from diaspora_event_sdk import KafkaConsumer, KafkaProducer
from kafka.errors import KafkaError

logger = logging.getLogger(__name__)


def value_serializer(v):
return json.dumps(v).encode("utf-8")


def value_deserializer(x):
return json.loads(x.decode("utf-8"))


class OctopusQueues(ColmenaQueues):
    """Colmena queues that communicate through Octopus (Diaspora) Kafka topics."""

    def __init__(
self,
topics: Collection[str],
prefix: str = "mofa_test1",
auto_offset_reset: str = "earliest",
        discard_events_before: Optional[int] = None,
serialization_method: Union[
str, SerializationMethod
] = SerializationMethod.PICKLE,
keep_inputs: bool = True,
proxystore_name: Optional[Union[str, Dict[str, str]]] = None,
proxystore_threshold: Optional[Union[int, Dict[str, int]]] = None,
):
assert os.environ["OCTOPUS_AWS_ACCESS_KEY_ID"]
assert os.environ["OCTOPUS_AWS_SECRET_ACCESS_KEY"]
assert os.environ["OCTOPUS_BOOTSTRAP_SERVERS"]
super().__init__(
topics,
serialization_method,
keep_inputs,
proxystore_name,
proxystore_threshold,
)
        # self.topics is handled in super().__init__
self.prefix = prefix
self.auto_offset_reset = auto_offset_reset
        if discard_events_before is None:
            # default to "now" at construction time; putting int(time() * 1000)
            # in the signature would freeze the cutoff at import time instead
            discard_events_before = int(time() * 1000)
        self.discard_events_before = discard_events_before

self.request_producer = None
self.request_consumer = None
self.result_consumers = {}

def connect_request_producer(self):
"""Connect the request producer."""
if not isinstance(self.request_producer, KafkaProducer):
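            # broker addresses and credentials are presumably picked up by the
            # Diaspora SDK from the OCTOPUS_* environment variables validated
            # in __init__ (no bootstrap_servers argument is passed here)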
self.request_producer = KafkaProducer(
value_serializer=value_serializer,
)

def connect_request_consumer(self):
"""Connect the request consumer."""
if not isinstance(self.request_consumer, KafkaConsumer):
request_topic = f"{self.prefix}_requests"
self.request_consumer = KafkaConsumer(
request_topic,
auto_offset_reset=self.auto_offset_reset,
value_deserializer=value_deserializer,
max_poll_records=1,
)

def connect_result_consumer(self, topic):
"""Connect a result consumer for a specific topic."""
if (topic not in self.result_consumers) or not isinstance(
self.result_consumers[topic], KafkaConsumer
):
result_topic = f"{self.prefix}_{topic}_result"
self.result_consumers[topic] = KafkaConsumer(
result_topic,
auto_offset_reset=self.auto_offset_reset,
value_deserializer=value_deserializer,
max_poll_records=1,
)

    def disconnect_request_producer(self):
        """Disconnect the request producer."""
        if self.request_producer:
            self.request_producer.close()
            # reset so connect_request_producer() builds a fresh client
            self.request_producer = None

def disconnect_request_consumer(self):
"""Disconnect the request consumer."""
if self.request_consumer:
self.request_consumer.close()
self.request_consumer = None

def disconnect_result_consumer(self, topic):
"""Disconnect the result consumer for a specific topic."""
consumer = self.result_consumers.pop(topic, None)
if consumer:
consumer.close()

    def __getstate__(self):
        state = super().__getstate__()

        # Kafka clients are not picklable: replace live connections with
        # markers so __setstate__ knows which ones to re-establish
        if self.request_producer:
            state["request_producer"] = "connected"

        if self.request_consumer:
            state["request_consumer"] = "connected"

        # build a fresh dict; writing through state["result_consumers"]
        # would mutate this instance's own dict via the shallow copy
        state["result_consumers"] = {
            topic: "connected" for topic in self.result_consumers
        }

        return state

def __setstate__(self, state):
super().__setstate__(state)

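        # the "connected" markers stored by __getstate__ are truthy, so
        # each connection that was live at pickle time is re-established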
if self.request_producer:
self.connect_request_producer()

if self.request_consumer:
self.connect_request_consumer()

for topic in list(self.result_consumers.keys()):
self.connect_result_consumer(topic)

    def _publish_event(self, message, octopus_topic):
        try:
            future = self.request_producer.send(octopus_topic, message)
            res = future.get(timeout=5)  # block until the send is acknowledged
            self.request_producer.flush()
            return res
        except KafkaError as e:
            logger.error("Error producing message: %s", e)

def _send_request(self, message: str, topic: str):
self.connect_request_producer()
queue = f"{self.prefix}_requests"
event = {"message": message, "topic": topic}
self._publish_event(event, queue)

    def _get_message(
        self,
        consumer,
        timeout: Optional[float] = None,
    ):
        if timeout is None:
            timeout = 0
        timeout *= 1000  # poll() takes milliseconds
        assert consumer, "consumer should be initialized"

        try:
            while True:  # blocks until a fresh event arrives
                events = consumer.poll(timeout)
                for _tp, records in events.items():
                    ts, event = records[0].timestamp, records[0].value
                    # skip events published before this queue was created
                    if ts < self.discard_events_before:
                        continue

                    return event

        except KafkaError as e:
            logger.error("Error consuming message: %s", e)
            raise TimeoutException()

    def _get_request(self, timeout: Optional[float] = None) -> Tuple[str, str]:
self.connect_request_consumer()

event = self._get_message(self.request_consumer, timeout)
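        # Colmena's kill signal arrives as a message ending in "null"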
if event["message"].endswith("null"):
raise KillSignalException()

topic, message = event["topic"], event["message"]
return topic, message

def _send_result(self, message: str, topic: str):
self.connect_request_producer()
queue = f"{self.prefix}_{topic}_result"
self._publish_event(message, queue)

    def _get_result(self, topic: str, timeout: Optional[float] = None) -> str:
self.connect_result_consumer(topic)
consumer = self.result_consumers.get(topic)
if not consumer:
raise ConnectionError(
f"No consumer connected for topic '{topic}'. Did you call 'connect_result_consumer('{topic}')'?"
)

event = self._get_message(consumer, timeout)
return event


if __name__ == "__main__":
queues = OctopusQueues(
topics=["generation", "lammps", "cp2k", "training", "assembly"]
)
print(queues)
print(queues.topics)

queues.connect_request_producer()
queues.connect_request_consumer()
for topic in queues.topics:
queues.connect_result_consumer(topic)

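    # round-trip through pickle: __getstate__/__setstate__ swap the live
    # Kafka clients for markers and re-establish them on load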
queues_dumped = pickle.dumps(queues)
print(queues_dumped)

queues_loaded = pickle.loads(queues_dumped)
print(queues_loaded.request_producer)
print(queues_loaded.request_consumer)
print(queues_loaded.result_consumers)
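
A minimal request/result round trip with these queues would look roughly like the sketch below. It is a sketch only: it assumes the OCTOPUS_* credentials are exported and that the prefixed topics (mofa_test1_requests, mofa_test1_generation_result) already exist on the broker, and it drives the internal _send_request/_get_request/_send_result/_get_result hooks directly, which Colmena's higher-level queue machinery would normally call.

queues = OctopusQueues(topics=["generation"])

# client side: enqueue a request ("task-payload" is a placeholder)
queues._send_request("task-payload", topic="generation")

# task-server side: pull the request, do the work, post the result
topic, message = queues._get_request(timeout=10)
queues._send_result(message, topic=topic)

# client side: collect the result for that topic
print(queues._get_result("generation", timeout=10))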
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -40,7 +40,10 @@ dependencies = [
"gpustat",

# Data management
"pymongo>=4"
"pymongo>=4",

# Mini-app
"diaspora-event-sdk[kafka-python]"
]

[tool.setuptools.packages.find]
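(The diaspora-event-sdk[kafka-python] extra supplies the KafkaProducer and KafkaConsumer classes, plus kafka.errors, that mofa/octopus.py imports above.)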
6 changes: 2 additions & 4 deletions run_parallel_workflow.py
@@ -32,7 +32,6 @@
 from colmena.models import Result
 from colmena.task_server.parsl import ParslTaskServer
 from colmena.queue import ColmenaQueues
-from colmena.queue.redis import RedisQueues
 from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent
 
 from mofa.assembly.assemble import assemble_many
@@ -47,6 +46,7 @@
 from mofa.hpc.colmena import DiffLinkerInference
 from mofa import db as mofadb
 from mofa.hpc.config import configs as hpc_configs, HPCConfig
+from mofa.octopus import OctopusQueues
 
 RDLogger.DisableLog('rdApp.*')
 ob.obErrorLog.SetOutputLevel(0)
@@ -627,9 +627,7 @@ def store_cp2k(self, result: Result):
     store = Store(name='redis', connector=RedisConnector(hostname=args.redis_host, port=6379), metrics=True)
     register_store(store)
 
-    # Configure to a use Redis queue, which allows streaming results form other nodes
-    queues = RedisQueues(
-        hostname=args.redis_host,
+    queues = OctopusQueues(
         topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'],
         proxystore_name='redis',
         proxystore_threshold=args.proxy_threshold
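Note that the ProxyStore Store above still uses the Redis connector: only the Colmena task queues move from Redis to Octopus, so a Redis host is still needed for proxied objects.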