Skip to content

Commit

Permalink
fix: python async (#309)
Browse files Browse the repository at this point in the history
* fix: python async

* fix python async
  • Loading branch information
stefan-gorules authored Feb 8, 2025
1 parent 9ca7ee0 commit 55a6320
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 68 deletions.
4 changes: 2 additions & 2 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ crate-type = ["cdylib"]
anyhow = { workspace = true }
either = "1.13"
pyo3 = { version = "0.23", features = ["anyhow", "serde"] }
pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime"] }
pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime", "attributes"] }
pythonize = "0.23"
json_dotpath = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
futures = "0.3"
tokio-util = { version = "0.7", features = ["rt"] }
zen-engine = { path = "../../core/engine" }
zen-expression = { path = "../../core/expression" }
zen-tmpl = { path = "../../core/template" }
34 changes: 21 additions & 13 deletions bindings/python/src/custom_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::anyhow;
use either::Either;
use pyo3::types::PyDict;
use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyObject, PyResult, Python};
use pyo3_async_runtimes::tokio;
use pyo3_async_runtimes::TaskLocals;
use pythonize::depythonize;

use zen_engine::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
Expand All @@ -11,17 +11,17 @@ use zen_engine::handler::node::{NodeResponse, NodeResult};
use crate::types::PyNodeRequest;

#[derive(Default)]
pub(crate) struct PyCustomNode(Option<Py<PyAny>>);

impl From<Py<PyAny>> for PyCustomNode {
fn from(value: Py<PyAny>) -> Self {
Self(Some(value))
}
pub(crate) struct PyCustomNode {
callback: Option<Py<PyAny>>,
task_locals: Option<TaskLocals>,
}

impl From<Option<PyObject>> for PyCustomNode {
fn from(value: Option<PyObject>) -> Self {
Self(value)
impl PyCustomNode {
pub fn new(callback: Option<Py<PyAny>>, task_locals: Option<TaskLocals>) -> Self {
Self {
callback,
task_locals,
}
}
}

Expand All @@ -33,7 +33,7 @@ fn extract_custom_node_response(py: Python<'_>, result: PyObject) -> NodeResult

impl CustomNodeAdapter for PyCustomNode {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
let Some(callable) = &self.0 else {
let Some(callable) = &self.callback else {
return Err(anyhow!("Custom node handler not provided"));
};

Expand All @@ -45,8 +45,16 @@ impl CustomNodeAdapter for PyCustomNode {
return Ok(Either::Left(extract_custom_node_response(py, result)));
}

let result_future = tokio::into_future(result.into_bound_py_any(py)?)?;
return Ok(Either::Right(result_future));
let Some(task_locals) = &self.task_locals else {
Err(anyhow!("Task locals are required in async context"))?
};

let result_future = pyo3_async_runtimes::into_future_with_locals(
task_locals,
result.into_bound_py_any(py)?,
)?;

Ok(Either::Right(result_future))
});

match maybe_result? {
Expand Down
59 changes: 36 additions & 23 deletions bindings/python/src/decision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ use anyhow::{anyhow, Context};
use pyo3::types::PyDict;
use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python};
use pyo3_async_runtimes::tokio;
use pyo3_async_runtimes::tokio::get_current_locals;
use pyo3_async_runtimes::tokio::re_exports::runtime::Runtime;
use pythonize::depythonize;
use serde_json::Value;
use zen_engine::{Decision, EvaluationOptions};

use crate::custom_node::PyCustomNode;
use crate::engine::PyZenEvaluateOptions;
use crate::loader::PyDecisionLoader;
use crate::mt::worker_pool;
use crate::value::PyValue;

#[pyclass]
Expand Down Expand Up @@ -40,16 +43,19 @@ impl PyZenDecision {
};

let decision = self.0.clone();
let result = futures::executor::block_on(decision.evaluate_with_opts(
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
))
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?;

let rt = Runtime::new()?;
let result = rt
.block_on(decision.evaluate_with_opts(
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
))
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?;

let value = serde_json::to_value(&result).context("Fail")?;
PyValue(value).into_py_any(py)
Expand All @@ -70,19 +76,26 @@ impl PyZenDecision {
};

let decision = self.0.clone();
let result = tokio::future_into_py(py, async move {
let result = futures::executor::block_on(decision.evaluate_with_opts(
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
))
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?;

let value = serde_json::to_value(result).context("Failed to serialize result")?;
let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move {
let value = worker_pool()
.spawn_pinned(move || async move {
decision
.evaluate_with_opts(
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
)
.await
.map(serde_json::to_value)
})
.await
.context("Failed to join threads")?
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?
.context("Failed to serialize result")?;

Python::with_gil(|py| PyValue(value).into_py_any(py))
})?;
Expand Down
69 changes: 40 additions & 29 deletions bindings/python/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use anyhow::{anyhow, Context};
use pyo3::prelude::PyDictMethods;
use pyo3::types::PyDict;
use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python};
use pyo3_async_runtimes::tokio;
use pyo3_async_runtimes::tokio::get_current_locals;
use pyo3_async_runtimes::{tokio, TaskLocals};
use pythonize::depythonize;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand All @@ -14,12 +15,13 @@ use zen_engine::{DecisionEngine, EvaluationOptions};
use crate::custom_node::PyCustomNode;
use crate::decision::PyZenDecision;
use crate::loader::PyDecisionLoader;
use crate::mt::{block_on, worker_pool};
use crate::value::PyValue;

#[pyclass]
#[pyo3(name = "ZenEngine")]
pub struct PyZenEngine {
graph: Arc<DecisionEngine<PyDecisionLoader, PyCustomNode>>,
engine: Arc<DecisionEngine<PyDecisionLoader, PyCustomNode>>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -40,11 +42,10 @@ impl Default for PyZenEvaluateOptions {
impl Default for PyZenEngine {
fn default() -> Self {
Self {
graph: DecisionEngine::new(
engine: Arc::new(DecisionEngine::new(
Arc::new(PyDecisionLoader::default()),
Arc::new(PyCustomNode::default()),
)
.into(),
)),
}
}
}
Expand All @@ -53,7 +54,7 @@ impl Default for PyZenEngine {
impl PyZenEngine {
#[new]
#[pyo3(signature = (maybe_options=None))]
pub fn new(maybe_options: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
pub fn new(py: Python, maybe_options: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
let Some(options) = maybe_options else {
return Ok(Default::default());
};
Expand All @@ -68,12 +69,16 @@ impl PyZenEngine {
None => None,
};

let task_locals = TaskLocals::with_running_loop(py)
.ok()
.map(|s| s.copy_context(py).ok())
.flatten();

Ok(Self {
graph: DecisionEngine::new(
engine: Arc::new(DecisionEngine::new(
Arc::new(PyDecisionLoader::from(loader)),
Arc::new(PyCustomNode::from(custom_node)),
)
.into(),
Arc::new(PyCustomNode::new(custom_node, task_locals)),
)),
})
}

Expand All @@ -92,8 +97,7 @@ impl PyZenEngine {
Default::default()
};

let graph = self.graph.clone();
let result = futures::executor::block_on(graph.evaluate_with_opts(
let result = block_on(self.engine.evaluate_with_opts(
key,
context.into(),
EvaluationOptions {
Expand Down Expand Up @@ -124,21 +128,28 @@ impl PyZenEngine {
Default::default()
};

let graph = self.graph.clone();
let result = tokio::future_into_py(py, async move {
let result = futures::executor::block_on(graph.evaluate_with_opts(
key,
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
))
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?;

let value = serde_json::to_value(result).context("Failed to serialize result")?;
let engine = self.engine.clone();
let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move {
let value = worker_pool()
.spawn_pinned(move || async move {
engine
.evaluate_with_opts(
key,
context.into(),
EvaluationOptions {
max_depth: options.max_depth,
trace: options.trace,
},
)
.await
.map(serde_json::to_value)
})
.await
.context("Failed to join threads")?
.map_err(|e| {
anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()))
})?
.context("Failed to serialize result")?;

Python::with_gil(|py| PyValue(value).into_py_any(py))
})?;
Expand All @@ -150,12 +161,12 @@ impl PyZenEngine {
let decision_content: DecisionContent =
serde_json::from_str(&content).context("Failed to serialize decision content")?;

let decision = self.graph.create_decision(decision_content.into());
let decision = self.engine.create_decision(decision_content.into());
Ok(PyZenDecision::from(decision))
}

pub fn get_decision<'py>(&'py self, _py: Python<'py>, key: String) -> PyResult<PyZenDecision> {
let decision = futures::executor::block_on(self.graph.get_decision(&key))
let decision = block_on(self.engine.get_decision(&key))
.context("Failed to find decision with given key")?;

Ok(PyZenDecision::from(decision))
Expand Down
1 change: 1 addition & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod decision;
mod engine;
mod expression;
mod loader;
mod mt;
mod types;
mod value;

Expand Down
25 changes: 25 additions & 0 deletions bindings/python/src/mt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use pyo3_async_runtimes::tokio::re_exports::runtime::Runtime;
use std::sync::OnceLock;
use std::thread::available_parallelism;
use tokio_util::task::LocalPoolHandle;

fn parallelism() -> usize {
available_parallelism().map(Into::into).unwrap_or(1)
}

pub(crate) fn worker_pool() -> LocalPoolHandle {
static LOCAL_POOL: OnceLock<LocalPoolHandle> = OnceLock::new();
LOCAL_POOL
.get_or_init(|| LocalPoolHandle::new(parallelism()))
.clone()
}

static RUNTIME: OnceLock<Runtime> = OnceLock::new();

fn get_runtime() -> &'static Runtime {
RUNTIME.get_or_init(|| Runtime::new().unwrap())
}

pub(crate) fn block_on<F: std::future::Future>(future: F) -> F::Output {
get_runtime().block_on(future)
}
15 changes: 14 additions & 1 deletion bindings/python/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import glob
import json
import os.path
import time
import unittest

import zen
Expand All @@ -26,7 +27,7 @@ def custom_handler(request):

async def custom_async_handler(request):
p1 = request.get_field("prop1")
await asyncio.sleep(0.25)
await asyncio.sleep(0.1)
return {
"output": {"sum": p1}
}
Expand Down Expand Up @@ -55,6 +56,18 @@ async def test_async_evaluate_custom_handler(self):
self.assertEqual(results[1]["result"]["sum"], 30)
self.assertEqual(results[2]["result"]["sum"], 40)

async def test_async_sleep_function(self):
engine = zen.ZenEngine({"loader": loader, "customHandler": custom_async_handler})

await engine.async_evaluate("sleep-function.json", {})
self.assertTrue(True)

async def test_async_http_function(self):
engine = zen.ZenEngine({"loader": loader, "customHandler": custom_async_handler})

await engine.async_evaluate("http-function.json", {})
self.assertTrue(True)

async def test_create_decisions_from_content(self):
engine = zen.ZenEngine()
with open("../../test-data/function.json", "r") as f:
Expand Down
12 changes: 12 additions & 0 deletions bindings/python/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ def test_render_template(self):
result = zen.render_template("{{ a + b }}", {"a": 10, "b": 20})
self.assertEqual(result, 30)

def test_sleep_function(self):
engine = zen.ZenEngine({"loader": loader, "customHandler": custom_handler})

engine.evaluate("sleep-function.json", {})
self.assertTrue(True)

def test_http_function(self):
engine = zen.ZenEngine({"loader": loader, "customHandler": custom_handler})

engine.evaluate("http-function.json", {})
self.assertTrue(True)

def test_evaluate_graphs(self):
engine = zen.ZenEngine({"loader": graph_loader})
json_files = glob.glob("../../test-data/graphs/*.json")
Expand Down
Loading

0 comments on commit 55a6320

Please sign in to comment.