diff --git a/dbos/_dbos.py b/dbos/_dbos.py index cc6391d9..9d29c319 100644 --- a/dbos/_dbos.py +++ b/dbos/_dbos.py @@ -45,7 +45,7 @@ start_workflow, workflow_wrapper, ) -from ._queue import Queue, _queue_thread +from ._queue import Queue, queue_thread from ._recovery import recover_pending_workflows, startup_recovery_thread from ._registrations import ( DEFAULT_MAX_RECOVERY_ATTEMPTS, @@ -283,6 +283,7 @@ def __init__( self.flask: Optional["Flask"] = flask self._executor_field: Optional[ThreadPoolExecutor] = None self._background_threads: List[threading.Thread] = [] + self._executor_id: str = os.environ.get("DBOS__VMID", "local") # If using FastAPI, set up middleware and lifecycle events if self.fastapi is not None: @@ -383,7 +384,7 @@ def _launch(self) -> None: evt = threading.Event() self.stop_events.append(evt) bg_queue_thread = threading.Thread( - target=_queue_thread, args=(evt, self), daemon=True + target=queue_thread, args=(evt, self), daemon=True ) bg_queue_thread.start() self._background_threads.append(bg_queue_thread) diff --git a/dbos/_queue.py b/dbos/_queue.py index 98c552fe..8ad75b47 100644 --- a/dbos/_queue.py +++ b/dbos/_queue.py @@ -51,13 +51,13 @@ def enqueue( return start_workflow(dbos, func, self.name, False, *args, **kwargs) -def _queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None: +def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None: while not stop_event.is_set(): if stop_event.wait(timeout=1): return for _, queue in dbos._registry.queue_info_map.items(): try: - wf_ids = dbos._sys_db.start_queued_workflows(queue) + wf_ids = dbos._sys_db.start_queued_workflows(queue, dbos._executor_id) for id in wf_ids: execute_workflow_by_id(dbos, id) except Exception: diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py index 61e9536a..e555297f 100644 --- a/dbos/_sys_db.py +++ b/dbos/_sys_db.py @@ -1104,7 +1104,7 @@ def enqueue(self, workflow_id: str, queue_name: str) -> None: .on_conflict_do_nothing() ) - def start_queued_workflows(self, queue: "Queue") -> List[str]: + def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]: start_time_ms = int(time.time() * 1000) if queue.limiter is not None: limiter_period_ms = int(queue.limiter["period"] * 1000) @@ -1159,7 +1159,7 @@ def start_queued_workflows(self, queue: "Queue") -> List[str]: if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]: break - # To start a function, first set its status to PENDING + # To start a function, first set its status to PENDING and update its executor ID c.execute( SystemSchema.workflow_status.update() .where(SystemSchema.workflow_status.c.workflow_uuid == id) @@ -1167,7 +1167,10 @@ def start_queued_workflows(self, queue: "Queue") -> List[str]: SystemSchema.workflow_status.c.status == WorkflowStatusString.ENQUEUED.value ) - .values(status=WorkflowStatusString.PENDING.value) + .values( + status=WorkflowStatusString.PENDING.value, + executor_id=executor_id, + ) ) # Then give it a start time