From 80deedbceaa07ddca43fcde21cdbdf45e3cc0369 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Thu, 1 Feb 2024 16:52:10 +0800 Subject: [PATCH] fix(bindings/python): global runtime for blocking row iterator --- bindings/python/src/blocking.rs | 16 +--------------- bindings/python/src/lib.rs | 1 + bindings/python/src/types.rs | 12 ++++++------ bindings/python/src/utils.rs | 30 ++++++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 21 deletions(-) create mode 100644 bindings/python/src/utils.rs diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index e82f876c9..9075adf22 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -15,21 +15,7 @@ use pyo3::prelude::*; use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION}; - -#[ctor::ctor] -static RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - -/// Utility to collect rust futures with GIL released -fn wait_for_future(py: Python, f: F) -> F::Output -where - F: Send, - F::Output: Send, -{ - py.allow_threads(|| RUNTIME.block_on(f)) -} +use crate::utils::wait_for_future; #[pyclass(module = "databend_driver")] pub struct BlockingDatabendClient(databend_driver::Client); diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 69cdd93d6..d8f3b8e3e 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -15,6 +15,7 @@ mod asyncio; mod blocking; mod types; +mod utils; use pyo3::prelude::*; diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 0b8158123..cb367dcdf 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -22,6 +22,8 @@ use pyo3_asyncio::tokio::future_into_py; use tokio::sync::Mutex; use tokio_stream::StreamExt; +use crate::utils::wait_for_future; + pub static VERSION: Lazy = Lazy::new(|| { let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); version.to_string() @@ -136,10 +138,9 @@ impl RowIterator { #[pymethods] impl RowIterator { - fn schema<'p>(&self) -> PyResult { + fn schema<'p>(&self, py: Python) -> PyResult { let streamer = self.0.clone(); - let rt = tokio::runtime::Runtime::new()?; - let ret = rt.block_on(async move { + let ret = wait_for_future(py, async move { let schema = streamer.lock().await.schema(); schema }); @@ -149,10 +150,9 @@ impl RowIterator { fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } - fn __next__(&self) -> PyResult> { + fn __next__(&self, py: Python) -> PyResult> { let streamer = self.0.clone(); - let rt = tokio::runtime::Runtime::new()?; - let ret = rt.block_on(async move { + let ret = wait_for_future(py, async move { match streamer.lock().await.next().await { Some(val) => match val { Err(e) => Err(PyException::new_err(format!("{}", e))), diff --git a/bindings/python/src/utils.rs b/bindings/python/src/utils.rs new file mode 100644 index 000000000..6f5ee3c7d --- /dev/null +++ b/bindings/python/src/utils.rs @@ -0,0 +1,30 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pyo3::prelude::*; + +#[ctor::ctor] +pub(crate) static RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + +/// Utility to collect rust futures with GIL released +pub(crate) fn wait_for_future(py: Python, f: F) -> F::Output +where + F: Send, + F::Output: Send, +{ + py.allow_threads(|| RUNTIME.block_on(f)) +}