Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made tokio initialization lazy #33

Merged
merged 1 commit into from
Aug 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,7 @@ fn rust_sleep(py: Python) -> PyResult<PyObject> {
#[pymodule]
fn my_async_module(py: Python, m: &PyModule) -> PyResult<()> {
pyo3_asyncio::try_init(py)?;
// Tokio needs explicit initialization before any pyo3-asyncio conversions.
// The module import is a prime place to do this.
pyo3_asyncio::tokio::init_multi_thread_once();


m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;

Ok(())
Expand Down
21 changes: 13 additions & 8 deletions pyo3-asyncio-macros/src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,23 @@ fn parse_knobs(

let config = config.build()?;

let mut rt = match config.flavor {
let builder = match config.flavor {
RuntimeFlavor::CurrentThread => quote! {
pyo3_asyncio::tokio::re_exports::runtime::Builder::new_current_thread()
},
RuntimeFlavor::Threaded => quote! {
pyo3_asyncio::tokio::re_exports::runtime::Builder::new_multi_thread()
},
};

let mut builder_init = quote! {
builder.enable_all();
};
if let Some(v) = config.worker_threads {
rt = quote! { #rt.worker_threads(#v) };
builder_init = quote! {
builder.worker_threads(#v);
#builder_init;
};
}

let rt_init = match config.flavor {
Expand All @@ -247,12 +254,10 @@ fn parse_knobs(
#body
}

pyo3_asyncio::tokio::init(
#rt
.enable_all()
.build()
.unwrap()
);
let mut builder = #builder;
#builder_init;

pyo3_asyncio::tokio::init(builder);

#rt_init

Expand Down
8 changes: 7 additions & 1 deletion pytests/test_tokio_current_thread_run_forever.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
mod tokio_run_forever;

fn main() {
pyo3_asyncio::tokio::init_current_thread();
let mut builder = tokio::runtime::Builder::new_current_thread();
builder.enable_all();

pyo3_asyncio::tokio::init(builder);
std::thread::spawn(move || {
pyo3_asyncio::tokio::get_runtime().block_on(std::future::pending::<()>());
});

tokio_run_forever::test_main();
}
2 changes: 0 additions & 2 deletions pytests/test_tokio_multi_thread_run_forever.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
mod tokio_run_forever;

fn main() {
pyo3_asyncio::tokio::init_multi_thread();

tokio_run_forever::test_main();
}
10 changes: 0 additions & 10 deletions pytests/tokio_asyncio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ fn test_init_twice() -> PyResult<()> {
common::test_init_twice()
}

#[pyo3_asyncio::tokio::test]
fn test_init_tokio_twice() -> PyResult<()> {
// tokio has already been initialized in test main. call these functions to
// make sure they don't cause problems with the other tests.
pyo3_asyncio::tokio::init_multi_thread_once();
pyo3_asyncio::tokio::init_current_thread_once();

Ok(())
}

#[pyo3_asyncio::tokio::test]
fn test_local_set_coroutine() -> PyResult<()> {
tokio::task::LocalSet::new().block_on(pyo3_asyncio::tokio::get_runtime(), async {
Expand Down
96 changes: 17 additions & 79 deletions src/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::{future::Future, thread};
use std::{future::Future, sync::Mutex};

use ::tokio::{
runtime::{Builder, Runtime},
task,
};
use futures::future::pending;
use once_cell::sync::OnceCell;
use once_cell::sync::{Lazy, OnceCell};
use pyo3::prelude::*;

use crate::generic;
Expand All @@ -30,10 +29,9 @@ pub use pyo3_asyncio_macros::tokio_main as main;
#[cfg(all(feature = "attributes", feature = "testing"))]
pub use pyo3_asyncio_macros::tokio_test as test;

static TOKIO_BUILDER: Lazy<Mutex<Builder>> = Lazy::new(|| Mutex::new(multi_thread()));
static TOKIO_RUNTIME: OnceCell<Runtime> = OnceCell::new();

const EXPECT_TOKIO_INIT: &str = "Tokio runtime must be initialized";

impl generic::JoinError for task::JoinError {
fn is_panic(&self) -> bool {
task::JoinError::is_panic(self)
Expand Down Expand Up @@ -65,79 +63,26 @@ impl generic::SpawnLocalExt for TokioRuntime {
}
}

/// Initialize the Tokio Runtime with a custom build
pub fn init(runtime: Runtime) {
TOKIO_RUNTIME
.set(runtime)
.expect("Tokio Runtime has already been initialized");
}

fn current_thread() -> Runtime {
Builder::new_current_thread()
.enable_all()
.build()
.expect("Couldn't build the current-thread Tokio runtime")
}

fn start_current_thread() {
thread::spawn(move || {
TOKIO_RUNTIME.get().unwrap().block_on(pending::<()>());
});
}

/// Initialize the Tokio Runtime with current-thread scheduler
///
/// # Panics
/// This function will panic if called a second time. See [`init_current_thread_once`] if you want
/// to avoid this panic.
pub fn init_current_thread() {
init(current_thread());
start_current_thread();
/// Initialize the Tokio runtime with a custom build
pub fn init(builder: Builder) {
*TOKIO_BUILDER.lock().unwrap() = builder
}

/// Get a reference to the current tokio runtime
pub fn get_runtime<'a>() -> &'a Runtime {
TOKIO_RUNTIME.get().expect(EXPECT_TOKIO_INIT)
}

fn multi_thread() -> Runtime {
Builder::new_multi_thread()
.enable_all()
.build()
.expect("Couldn't build the multi-thread Tokio runtime")
}

/// Initialize the Tokio Runtime with the multi-thread scheduler
///
/// # Panics
/// This function will panic if called a second time. See [`init_multi_thread_once`] if you want to
/// avoid this panic.
pub fn init_multi_thread() {
init(multi_thread());
}

/// Ensure that the Tokio Runtime is initialized
///
/// If the runtime has not been initialized already, the multi-thread scheduler
/// is used. Calling this function a second time is a no-op.
pub fn init_multi_thread_once() {
TOKIO_RUNTIME.get_or_init(|| multi_thread());
}

/// Ensure that the Tokio Runtime is initialized
///
/// If the runtime has not been initialized already, the current-thread
/// scheduler is used. Calling this function a second time is a no-op.
pub fn init_current_thread_once() {
let mut initialized = false;
TOKIO_RUNTIME.get_or_init(|| {
initialized = true;
current_thread()
});
TOKIO_BUILDER
.lock()
.unwrap()
.build()
.expect("Unable to build Tokio runtime")
})
}

if initialized {
start_current_thread();
}
fn multi_thread() -> Builder {
let mut builder = Builder::new_multi_thread();
builder.enable_all();
builder
}

/// Run the event loop until the given Future completes
Expand All @@ -157,16 +102,9 @@ pub fn init_current_thread_once() {
/// # use std::time::Duration;
/// #
/// # use pyo3::prelude::*;
/// # use tokio::runtime::{Builder, Runtime};
/// #
/// # let runtime = Builder::new_current_thread()
/// # .enable_all()
/// # .build()
/// # .expect("Couldn't build the runtime");
/// #
/// # Python::with_gil(|py| {
/// # pyo3_asyncio::with_runtime(py, || {
/// # pyo3_asyncio::tokio::init_current_thread();
/// pyo3_asyncio::tokio::run_until_complete(py, async move {
/// tokio::time::sleep(Duration::from_secs(1)).await;
/// Ok(())
Expand Down