Skip to content

Commit

Permalink
minimal upgrade to PyO3 0.23 (ignoring deprecations)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Nov 13, 2024
1 parent cd0346d commit cc14951
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 67 deletions.
28 changes: 11 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ rust-version = "1.75"
[dependencies]
# TODO it would be very nice to remove the "py-clone" feature as it can panic,
# but needs a bit of work to make sure it's not used in the codebase
pyo3 = { version = "0.22.5", features = ["generate-import-lib", "num-bigint", "py-clone"] }
pyo3 = { git = "https://github.com/pyo3/pyo3", branch = "release-0.23", features = ["generate-import-lib", "num-bigint", "py-clone"] }
regex = "1.11.1"
strum = { version = "0.26.3", features = ["derive"] }
strum_macros = "0.26.4"
Expand All @@ -46,7 +46,7 @@ base64 = "0.22.1"
num-bigint = "0.4.6"
python3-dll-a = "0.2.10"
uuid = "1.11.0"
jiter = { version = "0.7", features = ["python"] }
jiter = { git = "https://github.com/pydantic/jiter", branch = "dh/pyo3-0.23", features = ["python"] }
hex = "0.4.3"

[lib]
Expand Down Expand Up @@ -74,12 +74,12 @@ debug = true
strip = false

[dev-dependencies]
pyo3 = { version = "0.22.5", features = ["auto-initialize"] }
pyo3 = { git = "https://github.com/pyo3/pyo3", branch = "release-0.23", features = ["auto-initialize"] }

[build-dependencies]
version_check = "0.9.5"
# used where logic has to be version/distribution specific, e.g. pypy
pyo3-build-config = { version = "0.22.0" }
pyo3-build-config = { git = "https://github.com/pyo3/pyo3", branch = "release-0.23" }

[lints.clippy]
dbg_macro = "warn"
Expand Down
2 changes: 1 addition & 1 deletion src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ impl ValidationError {
.iter()
.map(|error| PyLineError::try_from(&error))
.collect::<PyResult<Vec<PyLineError>>>()?,
InputType::try_from(input_type)?,
InputType::try_from(input_type)?.into_py(cls.py()),
hide_input,
))
}
Expand Down
2 changes: 1 addition & 1 deletion src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ impl TzInfo {
}

#[allow(unused_variables)]
fn dst(&self, dt: &Bound<'_, PyAny>) -> Option<&PyDelta> {
fn dst(&self, dt: &Bound<'_, PyAny>) -> Option<&Bound<'_, PyDelta>> {
None
}

Expand Down
8 changes: 3 additions & 5 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,16 @@ pub(crate) fn iterate_mapping_items<'a, 'py>(
.items()
.map_err(|e| mapping_err(e, py, input))?
.iter()
.map_err(|e| mapping_err(e, py, input))?
.map(move |item| match item {
Ok(item) => item.extract().map_err(|_| {
.map(move |item| {
item.extract().map_err(|_| {
ValError::new(
ErrorType::MappingType {
error: MAPPING_TUPLE_ERROR.into(),
context: None,
},
input,
)
}),
Err(e) => Err(mapping_err(e, py, input)),
})
});
Ok(iterator)
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg_attr(has_coverage_attribute, feature(coverage_attribute))]
#![allow(deprecated)] // FIXME: just used during upgrading PyO3 to 0.23

extern crate core;

Expand Down
2 changes: 1 addition & 1 deletion src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ impl PathItem {
None
} else {
// otherwise, blindly try getitem on v since no better logic is realistic
py_any.get_item(self).ok()
py_any.get_item(self.to_object(py_any.py())).ok()
}
}

Expand Down
62 changes: 38 additions & 24 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cell::RefCell;
use std::fmt;
use std::sync::Mutex;

use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
Expand Down Expand Up @@ -366,18 +366,27 @@ impl From<bool> for WarningsMode {
}
}

#[derive(Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct CollectWarnings {
mode: WarningsMode,
warnings: RefCell<Option<Vec<String>>>,
// FIXME: mutex is to satisfy PyO3 0.23, we should be able to refactor this away
warnings: Mutex<Vec<String>>,
}

impl Clone for CollectWarnings {
fn clone(&self) -> Self {
Self {
mode: self.mode,
warnings: Mutex::new(self.warnings.lock().expect("lock poisoned").clone()),
}
}
}

impl CollectWarnings {
pub(crate) fn new(mode: WarningsMode) -> Self {
Self {
mode,
warnings: RefCell::new(None),
warnings: Mutex::new(Vec::new()),
}
}

Expand Down Expand Up @@ -443,41 +452,46 @@ impl CollectWarnings {
}

fn add_warning(&self, message: String) {
let mut op_warnings = self.warnings.borrow_mut();
if let Some(ref mut warnings) = *op_warnings {
warnings.push(message);
} else {
*op_warnings = Some(vec![message]);
}
self.warnings.lock().expect("lock poisoned").push(message)
}

pub fn final_check(&self, py: Python) -> PyResult<()> {
if self.mode == WarningsMode::None {
return Ok(());
}
match *self.warnings.borrow() {
Some(ref warnings) => {
let message = format!("Pydantic serializer warnings:\n {}", warnings.join("\n "));
if self.mode == WarningsMode::Warn {
let user_warning_type = py.import_bound("builtins")?.getattr("UserWarning")?;
PyErr::warn_bound(py, &user_warning_type, &message, 0)
} else {
Err(PydanticSerializationError::new_err(message))
}
}
_ => Ok(()),
let warnings = self.warnings.lock().expect("lock poisoned");

if warnings.is_empty() {
return Ok(());
}

let message = format!("Pydantic serializer warnings:\n {}", warnings.join("\n "));
if self.mode == WarningsMode::Warn {
let user_warning_type = py.import_bound("builtins")?.getattr("UserWarning")?;
PyErr::warn_bound(py, &user_warning_type, &message, 0)
} else {
Err(PydanticSerializationError::new_err(message))
}
}
}

#[derive(Default, Clone)]
#[derive(Default)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionState {
guard: RefCell<RecursionState>,
// FIXME: mutex is to satisfy PyO3 0.23, we should be able to refactor this away
guard: Mutex<RecursionState>,
}

impl Clone for SerRecursionState {
fn clone(&self) -> Self {
Self {
guard: Mutex::new(self.guard.lock().expect("lock poisoned").clone()),
}
}
}

impl ContainsRecursionState for &'_ Extra<'_> {
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
f(&mut self.rec_guard.guard.borrow_mut())
f(&mut self.rec_guard.guard.lock().expect("lock poisoned"))
}
}
10 changes: 7 additions & 3 deletions src/serializers/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ trait FilterLogic<T: Eq + Copy> {
next_exclude = Some(exc_value);
}
} else if let Ok(exclude_set) = exclude.downcast::<PySet>() {
if exclude_set.contains(py_key)? || exclude_set.contains(intern!(exclude_set.py(), "__all__"))? {
if exclude_set.contains(py_key.to_object(exclude_set.py()))?
|| exclude_set.contains(intern!(exclude_set.py(), "__all__"))?
{
// index is in the exclude set, we return Ok(None) to omit this index
return Ok(None);
}
Expand Down Expand Up @@ -205,7 +207,9 @@ trait FilterLogic<T: Eq + Copy> {
return Ok(None);
}
} else if let Ok(include_set) = include.downcast::<PySet>() {
if include_set.contains(py_key)? || include_set.contains(intern!(include_set.py(), "__all__"))? {
if include_set.contains(py_key.to_object(include_set.py()))?
|| include_set.contains(intern!(include_set.py(), "__all__"))?
{
return Ok(Some((None, next_exclude)));
} else if !self.explicit_include(int_key) {
// if the index is not in include, include exists, AND it's not in schema include,
Expand Down Expand Up @@ -332,7 +336,7 @@ fn merge_all_value<'py>(
dict: &Bound<'py, PyDict>,
py_key: impl ToPyObject + Copy,
) -> PyResult<Option<Bound<'py, PyAny>>> {
let op_item_value = dict.get_item(py_key)?;
let op_item_value = dict.get_item(py_key.to_object(dict.py()))?;
let op_all_value = dict.get_item(intern!(dict.py(), "__all__"))?;

match (op_item_value, op_all_value) {
Expand Down
12 changes: 6 additions & 6 deletions src/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ use crate::input::Int;
use jiter::{cached_py_string, pystring_fast_new, StringCacheMode};

pub trait SchemaDict<'py> {
fn get_as<T>(&self, key: &Bound<'_, PyString>) -> PyResult<Option<T>>
fn get_as<T>(&self, key: &Bound<'py, PyString>) -> PyResult<Option<T>>
where
T: FromPyObject<'py>;

fn get_as_req<T>(&self, key: &Bound<'_, PyString>) -> PyResult<T>
fn get_as_req<T>(&self, key: &Bound<'py, PyString>) -> PyResult<T>
where
T: FromPyObject<'py>;
}

impl<'py> SchemaDict<'py> for Bound<'py, PyDict> {
fn get_as<T>(&self, key: &Bound<'_, PyString>) -> PyResult<Option<T>>
fn get_as<T>(&self, key: &Bound<'py, PyString>) -> PyResult<Option<T>>
where
T: FromPyObject<'py>,
{
Expand All @@ -31,7 +31,7 @@ impl<'py> SchemaDict<'py> for Bound<'py, PyDict> {
}
}

fn get_as_req<T>(&self, key: &Bound<'_, PyString>) -> PyResult<T>
fn get_as_req<T>(&self, key: &Bound<'py, PyString>) -> PyResult<T>
where
T: FromPyObject<'py>,
{
Expand All @@ -43,7 +43,7 @@ impl<'py> SchemaDict<'py> for Bound<'py, PyDict> {
}

impl<'py> SchemaDict<'py> for Option<&Bound<'py, PyDict>> {
fn get_as<T>(&self, key: &Bound<'_, PyString>) -> PyResult<Option<T>>
fn get_as<T>(&self, key: &Bound<'py, PyString>) -> PyResult<Option<T>>
where
T: FromPyObject<'py>,
{
Expand All @@ -54,7 +54,7 @@ impl<'py> SchemaDict<'py> for Option<&Bound<'py, PyDict>> {
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn get_as_req<T>(&self, key: &Bound<'_, PyString>) -> PyResult<T>
fn get_as_req<T>(&self, key: &Bound<'py, PyString>) -> PyResult<T>
where
T: FromPyObject<'py>,
{
Expand Down
3 changes: 2 additions & 1 deletion src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ impl Validator for ArgumentsValidator {
},
VarKwargsMode::UnpackedTypedDict => {
// Save to the remaining kwargs, we will validate as a single dict:
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
remaining_kwargs
.set_item(either_str.as_py_string(py, state.cache_str()), value.to_object(py))?;
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ impl Validator for DataclassArgsValidator {
Err(err) => return Err(err),
}
} else {
output_dict
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
output_dict.set_item(
either_str.as_py_string(py, state.cache_str()),
value.to_object(py),
)?;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ pub(crate) fn create_decimal<'py>(arg: &Bound<'py, PyAny>, input: impl ToErrorVa

fn handle_decimal_new_error(input: impl ToErrorValue, error: PyErr, decimal_exception: Bound<'_, PyAny>) -> ValError {
let py = decimal_exception.py();
if error.matches(py, decimal_exception) {
if error.matches(py, decimal_exception).unwrap_or(false) {
ValError::new(ErrorTypeDefaults::DecimalParsing, input)
} else if error.matches(py, PyTypeError::type_object_bound(py)) {
} else if error.matches(py, PyTypeError::type_object_bound(py)).unwrap_or(false) {
ValError::new(ErrorTypeDefaults::DecimalType, input)
} else {
ValError::InternalErr(error)
Expand Down

0 comments on commit cc14951

Please sign in to comment.