Commit: to debug

haochenpan committed Nov 26, 2024
1 parent 23f1db3 commit 6aa4cf5
Showing 3 changed files with 122 additions and 24 deletions.
5 changes: 5 additions & 0 deletions envs/environment-cpu.yml
@@ -38,3 +38,8 @@ dependencies:
   - pip:
     - git+https://gitlab.com/ase/ase.git
     - -e ..[test]
+    - aws-msk-iam-sasl-signer-python
+    - confluent_kafka
+    - proxystore[all]
+    - iterators
+    - setuptools
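Note: two of the new Kafka dependencies work as a pair. aws-msk-iam-sasl-signer-python mints short-lived IAM tokens, and confluent_kafka consumes them through its OAUTHBEARER callback; this is how mofa/proxystream.py below authenticates to the Octopus brokers. A minimal sketch of that wiring, assuming a hypothetical AWS region and the OCTOPUS_* environment variables the module asserts:

# Minimal sketch of MSK IAM auth with confluent_kafka.
# Assumptions: the region "us-east-1" is hypothetical; the OCTOPUS_*
# variables are the ones asserted in mofa/proxystream.py.
import os

from aws_msk_iam_sasl_signer import MSKAuthTokenProvider
from confluent_kafka import Producer


def oauth_cb(oauth_config):
    # The signer returns (token, expiry_ms); confluent_kafka expects
    # the expiry as seconds since the epoch.
    token, expiry_ms = MSKAuthTokenProvider.generate_auth_token("us-east-1")
    return token, expiry_ms / 1000


conf = {
    "bootstrap.servers": os.environ["OCTOPUS_BOOTSTRAP_SERVERS"],
    "security.protocol": "SASL_SSL",
    "sasl.mechanisms": "OAUTHBEARER",
    "oauth_cb": oauth_cb,
}
producer = Producer(conf)
producer.produce("mofa_test2_requests", value=b"ping")
producer.flush()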
115 changes: 98 additions & 17 deletions mofa/proxystream.py
@@ -5,10 +5,11 @@
 import logging
 import os
 import pickle
+import threading
 from datetime import datetime
 from time import time
 from typing import Collection, Dict, Optional, Tuple, Union
-import threading
+from iterators import TimeoutIterator
 
 from aws_msk_iam_sasl_signer import MSKAuthTokenProvider
 from colmena.exceptions import KillSignalException, TimeoutException
@@ -21,10 +22,23 @@

 from proxystore.stream.shims.kafka import KafkaPublisher, KafkaSubscriber
 
+logger = logging.getLogger(__name__)
+
+assert os.environ["OCTOPUS_AWS_ACCESS_KEY_ID"]
+assert os.environ["OCTOPUS_AWS_SECRET_ACCESS_KEY"]
+assert os.environ["OCTOPUS_BOOTSTRAP_SERVERS"]
+
+assert os.environ["PROXYSTORE_GLOBUS_CLIENT_ID"]
+assert os.environ["PROXYSTORE_GLOBUS_CLIENT_SECRET"]
+
+assert os.environ["PROXYSTORE_ENDPOINT"]
+
+print(os.environ["PROXYSTORE_ENDPOINT"])
+
 class ProxyQueues(ColmenaQueues):
     def __init__(
         self,
+        store,
         topics: Collection[str],
         prefix: str = "mofa_test2",
         auto_offset_reset: str = "earliest",
@@ -36,15 +50,7 @@ def __init__(
         proxystore_name: Optional[Union[str, Dict[str, str]]] = None,
         proxystore_threshold: Optional[Union[int, Dict[str, int]]] = None,
     ):
-        assert os.environ["OCTOPUS_AWS_ACCESS_KEY_ID"]
-        assert os.environ["OCTOPUS_AWS_SECRET_ACCESS_KEY"]
-        assert os.environ["OCTOPUS_BOOTSTRAP_SERVERS"]
-
-        self.endpointConnector = EndpointConnector(
-            ["86b11712-50d7-4a08-a7cd-d316d7c50080"]
-        )
-        self.store = Store(proxystore_name, connector=self.endpointConnector)
-        register_store(self.store)
 
         super().__init__(
             topics,
@@ -54,6 +60,7 @@ def __init__(
             proxystore_threshold,
         )
         # self.topics is handled in super
+        self.store = store
         self.prefix = prefix
         self.auto_offset_reset = auto_offset_reset
         self.discard_events_before = discard_events_before
@@ -86,7 +93,7 @@ def oauth_cb(oauth_config):
     def connect_request_producer(self):
         """Connect the request producer."""
         if not isinstance(self.request_producer, StreamProducer):
-            conf = self.octopus_conf("my-group2", self.auto_offset_reset)
+            conf = self.octopus_conf("my-group", self.auto_offset_reset)
             producer = Producer(conf)
             publisher = KafkaPublisher(client=producer)
 
@@ -104,7 +111,7 @@ def connect_request_producer(self):
     def connect_request_consumer(self):
         """Connect the request consumer."""
         if not isinstance(self.request_consumer, StreamConsumer):
-            conf = self.octopus_conf("my-group2", self.auto_offset_reset)
+            conf = self.octopus_conf("my-group", self.auto_offset_reset)
             consumer = Consumer(conf)
             request_topic = f"{self.prefix}_requests"
             consumer.subscribe([request_topic])
@@ -187,29 +194,103 @@ def _get_message(
     ):
         if timeout is None:
             timeout = 0
-        timeout *= 1000  # to ms
+        # timeout *= 1000 # to ms
         assert consumer, "consumer should be initialized"
 
+        if not timeout or timeout == 0:
+            print("here1")
+            while True:
+                event = consumer.next_object()
+                if event:  # gives None if there is a timeout
+                    print("event here:", event)
+                    return event
+
+        else:
+            print("here2")
+            consumer_iter = consumer.iter_objects()
+            consumer_iter = TimeoutIterator(consumer_iter, timeout=timeout)
+
         try:
-            event = consumer.next_object()
-            print(event)
-            return event
+            while True:  # blocks indefinitely
+                for event in consumer_iter:
+                    if event:  # gives None if there is a timeout
+                        print("event here:", event)
+                        return event
+
+            # event = consumer.next_object()
+            # return event
+
         except Exception as e:
-            print(f"Error consuming message: {e}")
+            print(f"Error consuming message: {e}, {consumer}, {timeout}")
             raise TimeoutException()


     def _get_request(self, timeout: float = None) -> Tuple[str, str]:
         self.connect_request_consumer()
 
         event = self._get_message(self.request_consumer, timeout)
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        logger.warning(f"ProxyQueues::request event:: {current_time}, event={event}")
         if event["message"].endswith("null"):
             raise KillSignalException()
 
         topic, message = event["topic"], event["message"]
         return topic, message
 
     def _send_result(self, message: str, topic: str):
         self.connect_request_producer()
         queue = f"{self.prefix}_{topic}_result"
         self._publish_event(message, queue)
 
     def _get_result(self, topic: str, timeout: int = None) -> str:
         self.connect_result_consumer(topic)
         consumer = self.result_consumers.get(topic)
         if not consumer:
             raise ConnectionError(
                 f"No consumer connected for topic '{topic}'. Did you call 'connect_result_consumer('{topic}')'?"
             )
 
         event = self._get_message(consumer, timeout)
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        logger.warning(f"ProxyQueues::result event:: {current_time}, event={event}")
         return event
+
+
+if __name__ == "__main__":
+    # print(os.environ["PROXYSTORE_ENDPOINT"])
+
+    # endpoint_connector = EndpointConnector([os.environ["PROXYSTORE_ENDPOINT"]])
+
+    endpoints = [
+        "f4b4290b-8d3f-413e-882e-a2932213ade2",
+        "074b70c1-c85e-4e18-af86-70b443dfac0f"
+    ]
+    endpoint_connector = EndpointConnector(endpoints)
+
+    store = Store("my-endpoint2", connector=endpoint_connector)
+    register_store(store)
+
+    queues = ProxyQueues(
+        store=store,
+        topics=["generation", "lammps", "cp2k", "training", "assembly"],
-        proxystore_name="my-endpoint-store",
+        proxystore_name="my-endpoint2",
+    )
+    print(queues)
+    print(queues.topics)
+
+    queues.connect_request_producer()
+    queues.connect_request_consumer()
+    # for topic in queues.topics:
+    #     queues.connect_result_consumer(topic)
+
+    # queues_dumped = pickle.dumps(queues)
+    # print(queues_dumped)
+
+    # queues_loaded = pickle.loads(queues_dumped)
+    # print(queues_loaded.request_producer)
+    # print(queues_loaded.request_consumer)
+    # print(queues_loaded.result_consumers)
+
+    queues._send_request("123", "generation")
+    queues._get_message(queues.request_consumer, 0)
+    # print("here")
26 changes: 19 additions & 7 deletions run_parallel_workflow.py
@@ -20,8 +20,10 @@
 import hashlib
 import json
 import sys
+import os
 
 import pymongo
+from proxystore.connectors.endpoint import EndpointConnector
 from proxystore.connectors.redis import RedisConnector
 from proxystore.store import Store, register_store
 from rdkit import RDLogger
@@ -47,6 +49,7 @@
 from mofa import db as mofadb
 from mofa.hpc.config import configs as hpc_configs, HPCConfig
 from mofa.octopus import OctopusQueues
+from mofa.proxystream import ProxyQueues
 
 RDLogger.DisableLog('rdApp.*')
 ob.obErrorLog.SetOutputLevel(0)
@@ -627,14 +630,23 @@ def store_cp2k(self, result: Result):
     run_dir.mkdir(parents=True)
 
     # Open a proxystore with Redis
-    store = Store(name='redis', connector=RedisConnector(hostname=args.redis_host, port=6379), metrics=True)
+    endpoint_connector = EndpointConnector([os.environ["PROXYSTORE_ENDPOINT"]])
+    store = Store("my-endpoint", connector=endpoint_connector)
     register_store(store)
 
-    queues = OctopusQueues(
-        topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'],
-        # proxystore_name='redis',
-        # proxystore_threshold=args.proxy_threshold
+    queues = ProxyQueues(
+        store=store,
+        topics=["generation", "lammps", "cp2k", "training", "assembly"],
+        proxystore_name="my-endpoint",
     )
 
+    # store = Store(name='redis', connector=RedisConnector(hostname=args.redis_host, port=6379), metrics=True)
+    # register_store(store)
+    # queues = OctopusQueues(
+    #     topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'],
+    #     # proxystore_name='redis',
+    #     # proxystore_threshold=args.proxy_threshold
+    # )
+
     # Load the ligand descriptions
     templates = []
@@ -737,7 +749,7 @@ def store_cp2k(self, result: Result):
my_logger.info(f"Octopus::launch_option={args.launch_option}")
my_logger.info(f"Octopus::prefix={queues.prefix}")
my_logger.info(f"Octopus::discard={queues.discard_events_before}")
my_logger.info(f"Octopus::redis={args.redis_host}")
# my_logger.info(f"Octopus::redis={args.redis_host}")

# Save the run parameters to disk
(run_dir / 'params.json').write_text(json.dumps(run_params))
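Note: run_parallel_workflow.py now moves large objects through a ProxyStore Globus endpoint rather than Redis. A minimal round trip through such a store, assuming an endpoint has already been configured and started on this machine and its UUID is exported as PROXYSTORE_ENDPOINT; the payload is a placeholder:

# Minimal ProxyStore endpoint round trip, mirroring the setup in
# run_parallel_workflow.py. Assumes a running proxystore endpoint whose
# UUID is in PROXYSTORE_ENDPOINT; the payload below is hypothetical.
import os

from proxystore.connectors.endpoint import EndpointConnector
from proxystore.proxy import extract
from proxystore.store import Store, register_store

connector = EndpointConnector([os.environ["PROXYSTORE_ENDPOINT"]])
store = Store("my-endpoint", connector=connector)
register_store(store)

payload = {"mof": "hypothetical-record"}
proxy = store.proxy(payload)      # cheap reference, resolves lazily
print(extract(proxy) == payload)  # forces resolution through the endpoint

store.close()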
