From bfcaad705dd68f528b607ce4ae93582fe74b284b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 22 Jan 2025 01:08:30 +0100 Subject: [PATCH 1/3] proc-macro added for pyo3-stub-gen to generate stub pyi --- python/Cargo.toml | 7 ++++++- python/src/bin/stub_gen.rs | 8 ++++++++ python/src/egor.rs | 5 +++++ python/src/gp_mix.rs | 5 +++++ python/src/lib.rs | 4 ++++ python/src/sampling.rs | 4 ++++ python/src/sparse_gp_mix.rs | 5 +++++ python/src/types.rs | 14 ++++++++++++++ 8 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 python/src/bin/stub_gen.rs diff --git a/python/Cargo.toml b/python/Cargo.toml index b1e754bf..e19d2764 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [lib] name = "egobox" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [features] default = [] @@ -48,3 +48,8 @@ serde_json.workspace = true ctrlc.workspace = true argmin_testfunctions.workspace = true +pyo3-stub-gen = { path = "../../pyo3-stub-gen/pyo3-stub-gen", features = ["numpy"] } + +[[bin]] +name = "stub_gen" +doc = false diff --git a/python/src/bin/stub_gen.rs b/python/src/bin/stub_gen.rs new file mode 100644 index 00000000..6b8b8e0f --- /dev/null +++ b/python/src/bin/stub_gen.rs @@ -0,0 +1,8 @@ +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init(); + let stub = egobox::stub_info()?; + stub.generate()?; + Ok(()) +} diff --git a/python/src/egor.rs b/python/src/egor.rs index eda12ac7..6d1ab9e0 100644 --- a/python/src/egor.rs +++ b/python/src/egor.rs @@ -16,6 +16,7 @@ use ndarray::{concatenate, Array1, Array2, ArrayView2, Axis}; use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, ToPyArray}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods}; /// Utility function converting `xlimits` float data list specifying bounds of x components /// to x specified as a list of XType.Float types [egobox.XType] @@ -25,6 +26,7 @@ use pyo3::prelude::*; /// /// # Returns /// xtypes: nx-size list of XSpec(XType(FLOAT), [lower_bound, upper_bounds]) where `nx` is the dimension of x +#[gen_stub_pyfunction] #[pyfunction] pub(crate) fn to_specs(py: Python, xlimits: Vec>) -> PyResult { if xlimits.is_empty() || xlimits[0].is_empty() { @@ -149,6 +151,7 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec>) -> PyResult /// seed (int >= 0) /// Random generator seed to allow computation reproducibility. /// +#[gen_stub_pyclass] #[pyclass] pub(crate) struct Egor { pub xspecs: PyObject, @@ -174,6 +177,7 @@ pub(crate) struct Egor { pub seed: Option, } +#[gen_stub_pyclass] #[pyclass] pub(crate) struct OptimResult { #[pyo3(get)] @@ -186,6 +190,7 @@ pub(crate) struct OptimResult { y_doe: Py>, } +#[gen_stub_pymethods] #[pymethods] impl Egor { #[new] diff --git a/python/src/gp_mix.rs b/python/src/gp_mix.rs index 2f10caf4..d0d0c992 100644 --- a/python/src/gp_mix.rs +++ b/python/src/gp_mix.rs @@ -22,6 +22,7 @@ use ndarray::{Array1, Array2, Axis, Ix1, Ix2, Zip}; use ndarray_rand::rand::SeedableRng; use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2, PyReadonlyArrayDyn}; use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use rand_xoshiro::Xoshiro256Plus; /// Gaussian processes mixture builder @@ -69,6 +70,7 @@ use rand_xoshiro::Xoshiro256Plus; /// seed (int >= 0) /// Random generator seed to allow computation reproducibility. /// +#[gen_stub_pyclass] #[pyclass] pub(crate) struct GpMix { pub n_clusters: usize, @@ -82,6 +84,7 @@ pub(crate) struct GpMix { pub seed: Option, } +#[gen_stub_pymethods] #[pymethods] impl GpMix { #[new] @@ -218,9 +221,11 @@ impl GpMix { } /// A trained Gaussian processes mixture +#[gen_stub_pyclass] #[pyclass] pub(crate) struct Gpx(Box); +#[gen_stub_pymethods] #[pymethods] impl Gpx { /// Get Gaussian processes mixture builder aka `GpMix` diff --git a/python/src/lib.rs b/python/src/lib.rs index dd7cb5b9..4f366f96 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -14,6 +14,7 @@ use types::*; use env_logger::{Builder, Env}; use pyo3::prelude::*; +use pyo3_stub_gen::define_stub_info_gatherer; #[doc(hidden)] #[pymodule] @@ -55,3 +56,6 @@ fn egobox(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +// Define a function to gather stub information. +define_stub_info_gatherer!(stub_info); diff --git a/python/src/sampling.rs b/python/src/sampling.rs index 0be59a43..c5168875 100644 --- a/python/src/sampling.rs +++ b/python/src/sampling.rs @@ -3,7 +3,9 @@ use egobox_doe::{LhsKind, SamplingMethod}; use egobox_ego::gpmix::mixint::MixintContext; use numpy::{IntoPyArray, PyArray2}; use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass_enum, gen_stub_pyfunction}; +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "SCREAMING_SNAKE_CASE")] #[derive(Debug, Clone, Copy, PartialEq)] pub enum Sampling { @@ -27,6 +29,7 @@ pub enum Sampling { /// # Returns /// ndarray of shape (n_samples, n_variables) /// +#[gen_stub_pyfunction] #[pyfunction] #[pyo3(signature = (method, xspecs, n_samples, seed=None))] pub fn sampling( @@ -89,6 +92,7 @@ pub fn sampling( /// # Returns /// ndarray of shape (n_samples, n_variables) /// +#[gen_stub_pyfunction] #[pyfunction] #[pyo3(signature = (xspecs, n_samples, seed=None))] pub(crate) fn lhs( diff --git a/python/src/sparse_gp_mix.rs b/python/src/sparse_gp_mix.rs index 28bfe3c3..c0990ec4 100644 --- a/python/src/sparse_gp_mix.rs +++ b/python/src/sparse_gp_mix.rs @@ -21,6 +21,7 @@ use ndarray::{Array1, Array2, Axis, Ix1, Ix2, Zip}; use ndarray_rand::rand::SeedableRng; use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2}; use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use rand_xoshiro::Xoshiro256Plus; /// Sparse Gaussian processes mixture builder @@ -58,6 +59,7 @@ use rand_xoshiro::Xoshiro256Plus; /// seed (int >= 0) /// Random generator seed to allow computation reproducibility. /// +#[gen_stub_pyclass] #[pyclass] pub(crate) struct SparseGpMix { pub correlation_spec: CorrelationSpec, @@ -71,6 +73,7 @@ pub(crate) struct SparseGpMix { pub seed: Option, } +#[gen_stub_pymethods] #[pymethods] impl SparseGpMix { #[new] @@ -216,9 +219,11 @@ impl SparseGpMix { } /// A trained Gaussian processes mixture +#[gen_stub_pyclass] #[pyclass] pub(crate) struct SparseGpx(Box); +#[gen_stub_pymethods] #[pymethods] impl SparseGpx { /// Get Gaussian processes mixture builder aka `GpSparse` diff --git a/python/src/types.rs b/python/src/types.rs index 7aad25a8..80ce6c1a 100644 --- a/python/src/types.rs +++ b/python/src/types.rs @@ -1,5 +1,7 @@ use pyo3::prelude::*; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods}; +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Debug, Clone, PartialEq)] pub enum Recombination { @@ -12,10 +14,12 @@ pub enum Recombination { Smooth = 1, } +#[gen_stub_pyclass] #[pyclass] #[derive(Clone)] pub(crate) struct RegressionSpec(pub(crate) u8); +#[gen_stub_pymethods] #[pymethods] impl RegressionSpec { #[classattr] @@ -28,10 +32,12 @@ impl RegressionSpec { pub(crate) const QUADRATIC: u8 = egobox_moe::RegressionSpec::QUADRATIC.bits(); } +#[gen_stub_pyclass] #[pyclass] #[derive(Clone)] pub(crate) struct CorrelationSpec(pub(crate) u8); +#[gen_stub_pymethods] #[pymethods] impl CorrelationSpec { #[classattr] @@ -48,6 +54,7 @@ impl CorrelationSpec { pub(crate) const MATERN52: u8 = egobox_moe::CorrelationSpec::MATERN52.bits(); } +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum InfillStrategy { @@ -56,6 +63,7 @@ pub(crate) enum InfillStrategy { Wb2s = 3, } +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum ParInfillStrategy { @@ -65,6 +73,7 @@ pub(crate) enum ParInfillStrategy { Clmin = 4, } +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum InfillOptimizer { @@ -72,6 +81,7 @@ pub(crate) enum InfillOptimizer { Slsqp = 2, } +#[gen_stub_pyclass] #[pyclass] #[derive(Clone, Copy)] pub(crate) struct ExpectedOptimum { @@ -93,6 +103,7 @@ impl ExpectedOptimum { } } +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Clone, Copy, Debug, PartialEq)] pub(crate) enum XType { @@ -102,6 +113,7 @@ pub(crate) enum XType { Enum = 4, } +#[gen_stub_pyclass] #[pyclass] #[derive(FromPyObject, Debug)] pub(crate) struct XSpec { @@ -113,6 +125,7 @@ pub(crate) struct XSpec { pub(crate) tags: Vec, } +#[gen_stub_pymethods] #[pymethods] impl XSpec { #[new] @@ -126,6 +139,7 @@ impl XSpec { } } +#[gen_stub_pyclass_enum] #[pyclass(eq, eq_int, rename_all = "UPPERCASE")] #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum SparseMethod { From 84be6c2e0e29dd21df8fe82258ad47f8a9645532 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 22 Jan 2025 01:09:03 +0100 Subject: [PATCH 2/3] adapt python source folder --- python/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index e8b58bfd..180bafef 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,8 @@ build-backend = "maturin" requires = ["maturin>=1.0, <2.0"] [tool.maturin] -python-source = "./" +python-source = "egobox" +module-name = "egobox" features = ["pyo3/extension-module"] # Optional usage of BLAS backend # cargo-extra-args = "--features linfa/intel-mkl-static" From 327a42f42f40b5c72ba92f73cea3325e53786f3a Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 22 Jan 2025 01:09:42 +0100 Subject: [PATCH 3/3] Cargo.lock updated --- Cargo.lock | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 189 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0be81951..a0121801 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anes" version = "0.1.6" @@ -242,6 +257,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -336,6 +365,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.16" @@ -366,7 +401,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -387,7 +422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -556,6 +591,7 @@ dependencies = [ "numpy", "pyo3", "pyo3-log", + "pyo3-stub-gen", "rand_xoshiro", "serde", "serde_json", @@ -710,6 +746,12 @@ dependencies = [ "log", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "erased-serde" version = "0.4.5" @@ -853,6 +895,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" + [[package]] name = "heck" version = "0.5.0" @@ -871,6 +919,29 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -878,7 +949,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +dependencies = [ + "equivalent", + "hashbrown 0.15.2", ] [[package]] @@ -931,6 +1012,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -1114,6 +1204,12 @@ version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -1225,8 +1321,8 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af5a8477ac96877b5bd1fd67e0c28736c12943aba24eda92b127e036b0c8f400" dependencies = [ - "indexmap", - "itertools", + "indexmap 1.9.3", + "itertools 0.10.5", "ndarray", "noisy_float", "num-integer", @@ -1622,6 +1718,33 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "pyo3-stub-gen" +version = "0.6.2" +dependencies = [ + "anyhow", + "chrono", + "inventory", + "itertools 0.13.0", + "log", + "maplit", + "num-complex", + "numpy", + "pyo3", + "pyo3-stub-gen-derive", + "serde", + "toml", +] + +[[package]] +name = "pyo3-stub-gen-derive" +version = "0.6.2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "quote" version = "1.0.38" @@ -1836,6 +1959,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serial_test" version = "3.2.0" @@ -2093,6 +2225,40 @@ dependencies = [ "serde_json", ] +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap 2.7.1", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "typeid" version = "1.0.2" @@ -2284,6 +2450,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2366,6 +2541,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.6.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +dependencies = [ + "memchr", +] + [[package]] name = "zerocopy" version = "0.7.35"