diff --git a/README.md b/README.md index 4102f2f..a2e5cb2 100644 --- a/README.md +++ b/README.md @@ -162,10 +162,7 @@ fn rust_sleep(py: Python) -> PyResult { #[pymodule] fn my_async_module(py: Python, m: &PyModule) -> PyResult<()> { pyo3_asyncio::try_init(py)?; - // Tokio needs explicit initialization before any pyo3-asyncio conversions. - // The module import is a prime place to do this. - pyo3_asyncio::tokio::init_multi_thread_once(); - + m.add_function(wrap_pyfunction!(rust_sleep, m)?)?; Ok(()) diff --git a/pyo3-asyncio-macros/src/tokio.rs b/pyo3-asyncio-macros/src/tokio.rs index 98d711e..b30627e 100644 --- a/pyo3-asyncio-macros/src/tokio.rs +++ b/pyo3-asyncio-macros/src/tokio.rs @@ -219,7 +219,7 @@ fn parse_knobs( let config = config.build()?; - let mut rt = match config.flavor { + let builder = match config.flavor { RuntimeFlavor::CurrentThread => quote! { pyo3_asyncio::tokio::re_exports::runtime::Builder::new_current_thread() }, @@ -227,8 +227,15 @@ fn parse_knobs( pyo3_asyncio::tokio::re_exports::runtime::Builder::new_multi_thread() }, }; + + let mut builder_init = quote! { + builder.enable_all(); + }; if let Some(v) = config.worker_threads { - rt = quote! { #rt.worker_threads(#v) }; + builder_init = quote! { + builder.worker_threads(#v); + #builder_init; + }; } let rt_init = match config.flavor { @@ -247,12 +254,10 @@ fn parse_knobs( #body } - pyo3_asyncio::tokio::init( - #rt - .enable_all() - .build() - .unwrap() - ); + let mut builder = #builder; + #builder_init; + + pyo3_asyncio::tokio::init(builder); #rt_init diff --git a/pytests/test_tokio_current_thread_run_forever.rs b/pytests/test_tokio_current_thread_run_forever.rs index f2a72ad..439d506 100644 --- a/pytests/test_tokio_current_thread_run_forever.rs +++ b/pytests/test_tokio_current_thread_run_forever.rs @@ -1,7 +1,13 @@ mod tokio_run_forever; fn main() { - pyo3_asyncio::tokio::init_current_thread(); + let mut builder = tokio::runtime::Builder::new_current_thread(); + builder.enable_all(); + + pyo3_asyncio::tokio::init(builder); + std::thread::spawn(move || { + pyo3_asyncio::tokio::get_runtime().block_on(std::future::pending::<()>()); + }); tokio_run_forever::test_main(); } diff --git a/pytests/test_tokio_multi_thread_run_forever.rs b/pytests/test_tokio_multi_thread_run_forever.rs index 59d1ce3..243c1c6 100644 --- a/pytests/test_tokio_multi_thread_run_forever.rs +++ b/pytests/test_tokio_multi_thread_run_forever.rs @@ -1,7 +1,5 @@ mod tokio_run_forever; fn main() { - pyo3_asyncio::tokio::init_multi_thread(); - tokio_run_forever::test_main(); } diff --git a/pytests/tokio_asyncio/mod.rs b/pytests/tokio_asyncio/mod.rs index 8953d23..64da061 100644 --- a/pytests/tokio_asyncio/mod.rs +++ b/pytests/tokio_asyncio/mod.rs @@ -73,16 +73,6 @@ fn test_init_twice() -> PyResult<()> { common::test_init_twice() } -#[pyo3_asyncio::tokio::test] -fn test_init_tokio_twice() -> PyResult<()> { - // tokio has already been initialized in test main. call these functions to - // make sure they don't cause problems with the other tests. - pyo3_asyncio::tokio::init_multi_thread_once(); - pyo3_asyncio::tokio::init_current_thread_once(); - - Ok(()) -} - #[pyo3_asyncio::tokio::test] fn test_local_set_coroutine() -> PyResult<()> { tokio::task::LocalSet::new().block_on(pyo3_asyncio::tokio::get_runtime(), async { diff --git a/src/tokio.rs b/src/tokio.rs index c629918..d3e13bb 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -1,11 +1,10 @@ -use std::{future::Future, thread}; +use std::{future::Future, sync::Mutex}; use ::tokio::{ runtime::{Builder, Runtime}, task, }; -use futures::future::pending; -use once_cell::sync::OnceCell; +use once_cell::sync::{Lazy, OnceCell}; use pyo3::prelude::*; use crate::generic; @@ -30,10 +29,9 @@ pub use pyo3_asyncio_macros::tokio_main as main; #[cfg(all(feature = "attributes", feature = "testing"))] pub use pyo3_asyncio_macros::tokio_test as test; +static TOKIO_BUILDER: Lazy> = Lazy::new(|| Mutex::new(multi_thread())); static TOKIO_RUNTIME: OnceCell = OnceCell::new(); -const EXPECT_TOKIO_INIT: &str = "Tokio runtime must be initialized"; - impl generic::JoinError for task::JoinError { fn is_panic(&self) -> bool { task::JoinError::is_panic(self) @@ -65,79 +63,26 @@ impl generic::SpawnLocalExt for TokioRuntime { } } -/// Initialize the Tokio Runtime with a custom build -pub fn init(runtime: Runtime) { - TOKIO_RUNTIME - .set(runtime) - .expect("Tokio Runtime has already been initialized"); -} - -fn current_thread() -> Runtime { - Builder::new_current_thread() - .enable_all() - .build() - .expect("Couldn't build the current-thread Tokio runtime") -} - -fn start_current_thread() { - thread::spawn(move || { - TOKIO_RUNTIME.get().unwrap().block_on(pending::<()>()); - }); -} - -/// Initialize the Tokio Runtime with current-thread scheduler -/// -/// # Panics -/// This function will panic if called a second time. See [`init_current_thread_once`] if you want -/// to avoid this panic. -pub fn init_current_thread() { - init(current_thread()); - start_current_thread(); +/// Initialize the Tokio runtime with a custom build +pub fn init(builder: Builder) { + *TOKIO_BUILDER.lock().unwrap() = builder } /// Get a reference to the current tokio runtime pub fn get_runtime<'a>() -> &'a Runtime { - TOKIO_RUNTIME.get().expect(EXPECT_TOKIO_INIT) -} - -fn multi_thread() -> Runtime { - Builder::new_multi_thread() - .enable_all() - .build() - .expect("Couldn't build the multi-thread Tokio runtime") -} - -/// Initialize the Tokio Runtime with the multi-thread scheduler -/// -/// # Panics -/// This function will panic if called a second time. See [`init_multi_thread_once`] if you want to -/// avoid this panic. -pub fn init_multi_thread() { - init(multi_thread()); -} - -/// Ensure that the Tokio Runtime is initialized -/// -/// If the runtime has not been initialized already, the multi-thread scheduler -/// is used. Calling this function a second time is a no-op. -pub fn init_multi_thread_once() { - TOKIO_RUNTIME.get_or_init(|| multi_thread()); -} - -/// Ensure that the Tokio Runtime is initialized -/// -/// If the runtime has not been initialized already, the current-thread -/// scheduler is used. Calling this function a second time is a no-op. -pub fn init_current_thread_once() { - let mut initialized = false; TOKIO_RUNTIME.get_or_init(|| { - initialized = true; - current_thread() - }); + TOKIO_BUILDER + .lock() + .unwrap() + .build() + .expect("Unable to build Tokio runtime") + }) +} - if initialized { - start_current_thread(); - } +fn multi_thread() -> Builder { + let mut builder = Builder::new_multi_thread(); + builder.enable_all(); + builder } /// Run the event loop until the given Future completes @@ -157,16 +102,9 @@ pub fn init_current_thread_once() { /// # use std::time::Duration; /// # /// # use pyo3::prelude::*; -/// # use tokio::runtime::{Builder, Runtime}; -/// # -/// # let runtime = Builder::new_current_thread() -/// # .enable_all() -/// # .build() -/// # .expect("Couldn't build the runtime"); /// # /// # Python::with_gil(|py| { /// # pyo3_asyncio::with_runtime(py, || { -/// # pyo3_asyncio::tokio::init_current_thread(); /// pyo3_asyncio::tokio::run_until_complete(py, async move { /// tokio::time::sleep(Duration::from_secs(1)).await; /// Ok(())