From 0789aab43ed58cecb95ef906c4f507720287cb84 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 8 Aug 2024 08:50:16 -0400 Subject: [PATCH 01/15] Initial work on tagged union serializer - just to_python for now first pass, nothing working yet -- tagged union serializer still very much a wip, but wanted to checkpoint another checkpoint cleaning --- src/common/discriminator.rs | 49 ++++++ src/common/mod.rs | 1 + src/lib.rs | 1 + src/serializers/shared.rs | 3 +- src/serializers/type_serializers/union.rs | 202 +++++++++++++++++++++- src/validators/union.rs | 47 +---- 6 files changed, 249 insertions(+), 54 deletions(-) create mode 100644 src/common/discriminator.rs create mode 100644 src/common/mod.rs diff --git a/src/common/discriminator.rs b/src/common/discriminator.rs new file mode 100644 index 000000000..03e8da23c --- /dev/null +++ b/src/common/discriminator.rs @@ -0,0 +1,49 @@ +use pyo3::prelude::*; +use pyo3::types::PyString; +use pyo3::{PyTraverseError, PyVisit}; + +use crate::lookup_key::LookupKey; +use crate::py_gc::PyGcTraverse; + +#[derive(Debug, Clone)] +pub enum Discriminator { + /// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases + LookupKey(LookupKey), + /// call a function to find the tag to use + Function(PyObject), + /// Custom discriminator specifically for the root `Schema` union in self-schema + SelfSchema, +} + +impl Discriminator { + pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult { + if raw.is_callable() { + return Ok(Self::Function(raw.to_object(py))); + } else if let Ok(py_str) = raw.downcast::() { + if py_str.to_str()? == "self-schema-discriminator" { + return Ok(Self::SelfSchema); + } + } + + let lookup_key = LookupKey::from_py(py, raw, None)?; + Ok(Self::LookupKey(lookup_key)) + } + + pub fn to_string_py(&self, py: Python) -> PyResult { + match self { + Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)), + Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()), + Self::SelfSchema => Ok("self-schema".to_string()), + } + } +} + +impl PyGcTraverse for Discriminator { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + match self { + Self::Function(obj) => visit.call(obj)?, + Self::LookupKey(_) | Self::SelfSchema => {} + } + Ok(()) + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 000000000..dd75e9167 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1 @@ +pub(crate) mod discriminator; diff --git a/src/lib.rs b/src/lib.rs index eb598424b..7549a7662 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ mod py_gc; mod argument_markers; mod build_tools; +mod common; mod definitions; mod errors; mod input; diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index b9e0a727d..280ccbec7 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -88,7 +88,6 @@ combined_serializer! { // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer` // but aren't actually used for serialization, e.g. their `build` method must return another serializer find_only: { - super::type_serializers::union::TaggedUnionBuilder; super::type_serializers::other::ChainBuilder; super::type_serializers::other::CustomErrorBuilder; super::type_serializers::other::CallBuilder; @@ -138,6 +137,7 @@ combined_serializer! { Json: super::type_serializers::json::JsonSerializer; JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer; Union: super::type_serializers::union::UnionSerializer; + TaggedUnion: super::type_serializers::union::TaggedUnionSerializer; Literal: super::type_serializers::literal::LiteralSerializer; Enum: super::type_serializers::enum_::EnumSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; @@ -246,6 +246,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit), CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::TaggedUnion(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 51fc07e8d..028172268 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -1,10 +1,13 @@ +use ahash::AHashMap as HashMap; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; use crate::build_tools::py_schema_err; +use crate::common::discriminator::Discriminator; use crate::definitions::DefinitionsBuilder; +use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; @@ -180,9 +183,15 @@ impl TypeSerializer for UnionSerializer { } } -pub struct TaggedUnionBuilder; +#[derive(Debug, Clone)] +pub struct TaggedUnionSerializer { + discriminator: Discriminator, + lookup: HashMap, + choices: Vec, + name: String, +} -impl BuildSerializer for TaggedUnionBuilder { +impl BuildSerializer for TaggedUnionSerializer { const EXPECTED_TYPE: &'static str = "tagged-union"; fn build( @@ -190,14 +199,191 @@ impl BuildSerializer for TaggedUnionBuilder { config: Option<&Bound<'_, PyDict>>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_choices: Bound<'_, PyDict> = schema.get_as_req(intern!(schema.py(), "choices"))?; - let mut choices: Vec = Vec::with_capacity(schema_choices.len()); + let py = schema.py(); + let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?; + + let choice_list: Bound = schema.get_as_req(intern!(py, "choices"))?; + let mut lookup: HashMap = HashMap::with_capacity(choice_list.len()); + let mut choices: Vec = Vec::with_capacity(choice_list.len()); + + for (choice_key, chice_schema) in choice_list { + let serializer = CombinedSerializer::build(chice_schema.downcast()?, config, definitions).unwrap(); + choices.push(serializer.clone()); + lookup.insert(choice_key.extract::()?, serializer); + } + + let descr = choices + .iter() + .map(TypeSerializer::get_name) + .collect::>() + .join(", "); + + Ok(Self { + discriminator, + lookup, + choices, + name: format!("TaggedUnion[{descr}]"), + } + .into()) + } +} + +impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, lookup }); + +impl TypeSerializer for TaggedUnionSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + match &self.discriminator { + Discriminator::LookupKey(lookup_key) => { + let discriminator_value = match lookup_key { + LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok(), + _ => None, + }; + + if let Some(tag) = discriminator_value { + if let Ok(tag_str) = tag.extract::() { + if let Some(serializer) = self.lookup.get(&tag_str) { + return serializer.to_python(value, include, exclude, &new_extra); + } + } else { + return Err(self.tag_not_found()); + } + } + + let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); + if let Ok(s) = basic_union_ser { + return s.to_python(value, include, exclude, extra); + } + } + Discriminator::Function(func) => { + // try calling the method directly on the object + let discriminator_value = func.call1(py, (value,)).ok().or_else(|| { + // Try converting object to a dict, might be more compatible with poorly defined callable discriminator + value + .call_method0(intern!(py, "dict")) + .and_then(|v| func.call1(py, (v.to_object(py),))) + .ok() + }); + + if let Some(tag) = discriminator_value { + if let Ok(tag_str) = tag.extract::(py) { + if let Some(serializer) = self.lookup.get(&tag_str) { + return serializer.to_python(value, include, exclude, &new_extra); + } + } else { + return Err(self.tag_not_found()); + } + } + + let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); + if let Ok(s) = basic_union_ser { + return s.to_python(value, include, exclude, extra); + } + } + Discriminator::SelfSchema => { + // not really sure about this case, but it's here for completeness + for comb_serializer in &self.choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(value.py()) { + true => (), + false => return Err(err), + }, + } + } + } + } + + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + // TODO: implement this + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + for comb_serializer in &self.choices { + match comb_serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(key.py()) { + true => (), + false => return Err(err), + }, + } + } + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + for comb_serializer in &self.choices { + match comb_serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(key.py()) { + true => (), + false => return Err(err), + }, + } + } + } - for (_, value) in schema_choices { - if let Ok(choice_schema) = value.downcast::() { - choices.push(CombinedSerializer::build(choice_schema, config, definitions)?); + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; + infer_json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + // TODO: implement this + + let py = value.py(); + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + for comb_serializer in &self.choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { + true => (), + false => return Err(py_err_se_err(err)), + }, } } - UnionSerializer::from_choices(choices) + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + for comb_serializer in &self.choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { + true => (), + false => return Err(py_err_se_err(err)), + }, + } + } + } + + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + &self.name + } +} + +impl TaggedUnionSerializer { + fn tag_not_found(&self) -> PyErr { + PydanticSerializationUnexpectedValue::new_err(Some("Tag not found in tagged union for value: {:?}".to_string())) } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 6fb9fb070..9a9c6c372 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use std::str::FromStr; +use crate::py_gc::PyGcTraverse; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -8,10 +9,9 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; +use crate::common::discriminator::Discriminator; use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, Input, ValidatedDict}; -use crate::lookup_key::LookupKey; -use crate::py_gc::PyGcTraverse; use crate::tools::SchemaDict; use super::custom_error::CustomError; @@ -295,49 +295,6 @@ impl<'a> MaybeErrors<'a> { } } -#[derive(Debug, Clone)] -enum Discriminator { - /// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases - LookupKey(LookupKey), - /// call a function to find the tag to use - Function(PyObject), - /// Custom discriminator specifically for the root `Schema` union in self-schema - SelfSchema, -} - -impl Discriminator { - fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult { - if raw.is_callable() { - return Ok(Self::Function(raw.to_object(py))); - } else if let Ok(py_str) = raw.downcast::() { - if py_str.to_str()? == "self-schema-discriminator" { - return Ok(Self::SelfSchema); - } - } - - let lookup_key = LookupKey::from_py(py, raw, None)?; - Ok(Self::LookupKey(lookup_key)) - } - - fn to_string_py(&self, py: Python) -> PyResult { - match self { - Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)), - Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()), - Self::SelfSchema => Ok("self-schema".to_string()), - } - } -} - -impl PyGcTraverse for Discriminator { - fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { - match self { - Self::Function(obj) => visit.call(obj)?, - Self::LookupKey(_) | Self::SelfSchema => {} - } - Ok(()) - } -} - #[derive(Debug)] pub struct TaggedUnionValidator { discriminator: Discriminator, From 80839e071da36ac96bbb962b2152c15bb48a7906 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 8 Aug 2024 11:39:10 -0400 Subject: [PATCH 02/15] removing self schema according to https://github.com/pydantic/pydantic-core/commit/5efeaf9fe6cb5e57c77e42644387afa781b127e0 --- src/common/discriminator.rs | 10 +----- src/input/input_abstract.rs | 4 --- src/input/input_json.rs | 4 --- src/input/input_python.rs | 7 ----- src/input/input_string.rs | 3 -- src/serializers/type_serializers/union.rs | 12 -------- src/validators/union.rs | 37 ----------------------- 7 files changed, 1 insertion(+), 76 deletions(-) diff --git a/src/common/discriminator.rs b/src/common/discriminator.rs index 03e8da23c..8995fdcfc 100644 --- a/src/common/discriminator.rs +++ b/src/common/discriminator.rs @@ -1,5 +1,4 @@ use pyo3::prelude::*; -use pyo3::types::PyString; use pyo3::{PyTraverseError, PyVisit}; use crate::lookup_key::LookupKey; @@ -11,18 +10,12 @@ pub enum Discriminator { LookupKey(LookupKey), /// call a function to find the tag to use Function(PyObject), - /// Custom discriminator specifically for the root `Schema` union in self-schema - SelfSchema, } impl Discriminator { pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult { if raw.is_callable() { return Ok(Self::Function(raw.to_object(py))); - } else if let Ok(py_str) = raw.downcast::() { - if py_str.to_str()? == "self-schema-discriminator" { - return Ok(Self::SelfSchema); - } } let lookup_key = LookupKey::from_py(py, raw, None)?; @@ -33,7 +26,6 @@ impl Discriminator { match self { Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)), Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()), - Self::SelfSchema => Ok("self-schema".to_string()), } } } @@ -42,7 +34,7 @@ impl PyGcTraverse for Discriminator { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { match self { Self::Function(obj) => visit.call(obj)?, - Self::LookupKey(_) | Self::SelfSchema => {} + Self::LookupKey(_) => {} } Ok(()) } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index b0e058d9b..ea1fe0e56 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -237,7 +237,6 @@ pub trait ValidatedDict<'py> { where Self: 'a; fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult)>>; - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>>; // FIXME this is a bit of a leaky abstraction fn is_py_get_attr(&self) -> bool { false @@ -280,9 +279,6 @@ impl<'py> ValidatedDict<'py> for Never { fn get_item<'k>(&self, _key: &'k LookupKey) -> ValResult)>> { unreachable!() } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - unreachable!() - } fn iterate<'a, R>( &'a self, _consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 3adc36ba6..824333ab8 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -476,10 +476,6 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> { key.json_get(self) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - None - } - fn iterate<'a, R>( &'a self, consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 7840a825a..048fb5c27 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -775,13 +775,6 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> { matches!(self, Self::GetAttr(..)) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - match self { - Self::Dict(dict) => Some(dict), - _ => None, - } - } - fn iterate<'a, R>( &'a self, consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 3ef1b58ce..ea0482b37 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -284,9 +284,6 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> { fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult)>> { key.py_get_string_mapping_item(&self.0) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - None - } fn iterate<'a, R>( &'a self, consumer: impl super::ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 028172268..0ba39fd26 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -290,18 +290,6 @@ impl TypeSerializer for TaggedUnionSerializer { return s.to_python(value, include, exclude, extra); } } - Discriminator::SelfSchema => { - // not really sure about this case, but it's here for completeness - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => return Err(err), - }, - } - } - } } extra.warnings.on_fallback_py(self.get_name(), value, extra)?; diff --git a/src/validators/union.rs b/src/validators/union.rs index 9a9c6c372..aa1c37883 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -345,11 +345,6 @@ impl BuildValidator for TaggedUnionValidator { let key = intern!(py, "from_attributes"); let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true); - let descr = match discriminator { - Discriminator::SelfSchema => "self-schema".to_string(), - _ => descr, - }; - Ok(Self { discriminator, lookup, @@ -393,9 +388,6 @@ impl Validator for TaggedUnionValidator { self.find_call_validator(py, tag.bind(py), input, state) } } - Discriminator::SelfSchema => { - self.find_call_validator(py, self.self_schema_tag(py, input, state)?.as_any(), input, state) - } } } @@ -405,35 +397,6 @@ impl Validator for TaggedUnionValidator { } impl TaggedUnionValidator { - fn self_schema_tag<'py>( - &self, - py: Python<'py>, - input: &(impl Input<'py> + ?Sized), - state: &mut ValidationState<'_, 'py>, - ) -> ValResult> { - let dict = input.strict_dict()?; - let dict = dict.as_py_dict().expect("self schema is always a Python dictionary"); - let tag = match dict.get_item(intern!(py, "type"))? { - Some(t) => t.downcast_into::()?, - None => return Err(self.tag_not_found(input)), - }; - let tag = tag.to_str()?; - // custom logic to distinguish between different function and tuple schemas - if tag == "function" { - let Some(mode) = dict.get_item(intern!(py, "mode"))? else { - return Err(self.tag_not_found(input)); - }; - let tag = match mode.validate_str(true, false)?.into_inner().as_cow()?.as_ref() { - "plain" => Ok(intern!(py, "function-plain").to_owned()), - "wrap" => Ok(intern!(py, "function-wrap").to_owned()), - _ => Ok(intern!(py, "function").to_owned()), - }; - tag - } else { - Ok(state.maybe_cached_str(py, tag)) - } - } - fn find_call_validator<'py>( &self, py: Python<'py>, From c69263d9163f94dc18e2b45147cdaccc9d90c6fd Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 8 Aug 2024 14:00:53 -0400 Subject: [PATCH 03/15] finish logic up --- src/serializers/type_serializers/union.rs | 179 ++++++++++++---------- 1 file changed, 97 insertions(+), 82 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 0ba39fd26..1fef9214e 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -243,53 +243,35 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - match &self.discriminator { - Discriminator::LookupKey(lookup_key) => { - let discriminator_value = match lookup_key { - LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok(), - _ => None, - }; - - if let Some(tag) = discriminator_value { - if let Ok(tag_str) = tag.extract::() { - if let Some(serializer) = self.lookup.get(&tag_str) { - return serializer.to_python(value, include, exclude, &new_extra); - } - } else { - return Err(self.tag_not_found()); + let discriminator_value = self.get_discriminator_value(value); + + if let Some(tag) = discriminator_value { + if let Ok(tag_str) = tag.extract::(py) { + if let Some(serializer) = self.lookup.get(&tag_str) { + match serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.to_python(value, include, exclude, &new_extra); + } + } + false => return Err(err), + }, } + } else { + return Err(self.tag_not_found()); } - - let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); - if let Ok(s) = basic_union_ser { - return s.to_python(value, include, exclude, extra); - } + } else { + return Err(self.tag_not_found()); } - Discriminator::Function(func) => { - // try calling the method directly on the object - let discriminator_value = func.call1(py, (value,)).ok().or_else(|| { - // Try converting object to a dict, might be more compatible with poorly defined callable discriminator - value - .call_method0(intern!(py, "dict")) - .and_then(|v| func.call1(py, (v.to_object(py),))) - .ok() - }); - - if let Some(tag) = discriminator_value { - if let Ok(tag_str) = tag.extract::(py) { - if let Some(serializer) = self.lookup.get(&tag_str) { - return serializer.to_python(value, include, exclude, &new_extra); - } - } else { - return Err(self.tag_not_found()); - } - } + } - let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); - if let Ok(s) = basic_union_ser { - return s.to_python(value, include, exclude, extra); - } - } + // Fallback processing + let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); + if let Ok(s) = basic_union_ser { + return s.to_python(value, include, exclude, extra); } extra.warnings.on_fallback_py(self.get_name(), value, extra)?; @@ -297,31 +279,39 @@ impl TypeSerializer for TaggedUnionSerializer { } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - // TODO: implement this let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - for comb_serializer in &self.choices { - match comb_serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(key.py()) { - true => (), - false => return Err(err), - }, - } - } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(key.py()) { - true => (), - false => return Err(err), - }, + + let py = key.py(); + + let discriminator_value = self.get_discriminator_value(key); + + if let Some(tag) = discriminator_value { + if let Ok(tag_str) = tag.extract::(py) { + if let Some(serializer) = self.lookup.get(&tag_str) { + match serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.json_key(key, &new_extra); + } + } + } + } else { + return Err(self.tag_not_found()); } + } else { + return Err(self.tag_not_found()); } } + // Fallback processing + let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); + if let Ok(s) = basic_union_ser { + return s.json_key(key, extra); + } + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) } @@ -334,33 +324,41 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - // TODO: implement this - let py = value.py(); let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(py) { - true => (), - false => return Err(py_err_se_err(err)), - }, - } - } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(py) { - true => (), - false => return Err(py_err_se_err(err)), - }, + + let discriminator_value = self.get_discriminator_value(value); + + if let Some(tag) = discriminator_value { + if let Ok(tag_str) = tag.extract::(py) { + if let Some(selected_serializer) = self.lookup.get(&tag_str) { + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => return Err(py_err_se_err(err)), + } + } + } + } + } else { + return Err(py_err_se_err(self.tag_not_found())); } + } else { + return Err(py_err_se_err(self.tag_not_found())); } } + // Fallback processing + let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); + if let Ok(s) = basic_union_ser { + return s.serde_serialize(value, serializer, include, exclude, extra); + } + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; infer_serialize(value, serializer, include, exclude, extra) } @@ -371,6 +369,23 @@ impl TypeSerializer for TaggedUnionSerializer { } impl TaggedUnionSerializer { + fn get_discriminator_value(&self, value: &Bound<'_, PyAny>) -> Option> { + let py = value.py(); + match &self.discriminator { + Discriminator::LookupKey(lookup_key) => match lookup_key { + LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok().map(|obj| obj.to_object(py)), + _ => None, + }, + Discriminator::Function(func) => func.call1(py, (value,)).ok().or_else(|| { + // Try converting object to a dict, might be more compatible with poorly defined callable discriminator + value + .call_method0(intern!(py, "dict")) + .and_then(|v| func.call1(py, (v.to_object(py),))) + .ok() + }), + } + } + fn tag_not_found(&self) -> PyErr { PydanticSerializationUnexpectedValue::new_err(Some("Tag not found in tagged union for value: {:?}".to_string())) } From 4d7cb3703fdd4f2663fd723d6bc52e3fb722a3ea Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 8 Aug 2024 14:28:56 -0400 Subject: [PATCH 04/15] use to string conversion --- src/serializers/type_serializers/union.rs | 77 ++++++++++------------- 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 1fef9214e..c0caea5a4 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -206,10 +206,10 @@ impl BuildSerializer for TaggedUnionSerializer { let mut lookup: HashMap = HashMap::with_capacity(choice_list.len()); let mut choices: Vec = Vec::with_capacity(choice_list.len()); - for (choice_key, chice_schema) in choice_list { - let serializer = CombinedSerializer::build(chice_schema.downcast()?, config, definitions).unwrap(); + for (choice_key, choice_schema) in choice_list { + let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions).unwrap(); choices.push(serializer.clone()); - lookup.insert(choice_key.extract::()?, serializer); + lookup.insert(choice_key.to_string(), serializer); } let descr = choices @@ -246,22 +246,19 @@ impl TypeSerializer for TaggedUnionSerializer { let discriminator_value = self.get_discriminator_value(value); if let Some(tag) = discriminator_value { - if let Ok(tag_str) = tag.extract::(py) { - if let Some(serializer) = self.lookup.get(&tag_str) { - match serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(py) { - true => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - return serializer.to_python(value, include, exclude, &new_extra); - } + let tag_str = tag.to_string(); + if let Some(serializer) = self.lookup.get(&tag_str) { + match serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.to_python(value, include, exclude, &new_extra); } - false => return Err(err), - }, - } - } else { - return Err(self.tag_not_found()); + } + false => return Err(err), + }, } } else { return Err(self.tag_not_found()); @@ -282,24 +279,19 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let py = key.py(); - let discriminator_value = self.get_discriminator_value(key); if let Some(tag) = discriminator_value { - if let Ok(tag_str) = tag.extract::(py) { - if let Some(serializer) = self.lookup.get(&tag_str) { - match serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - return serializer.json_key(key, &new_extra); - } + let tag_str = tag.to_string(); + if let Some(serializer) = self.lookup.get(&tag_str) { + match serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.json_key(key, &new_extra); } } - } else { - return Err(self.tag_not_found()); } } else { return Err(self.tag_not_found()); @@ -331,22 +323,19 @@ impl TypeSerializer for TaggedUnionSerializer { let discriminator_value = self.get_discriminator_value(value); if let Some(tag) = discriminator_value { - if let Ok(tag_str) = tag.extract::(py) { - if let Some(selected_serializer) = self.lookup.get(&tag_str) { - match selected_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - match selected_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => return Err(py_err_se_err(err)), - } + let tag_str = tag.to_string(); + if let Some(selected_serializer) = self.lookup.get(&tag_str) { + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => return Err(py_err_se_err(err)), } } } - } else { - return Err(py_err_se_err(self.tag_not_found())); } } else { return Err(py_err_se_err(self.tag_not_found())); From 37954ebfe0c69382de435f988a7fbcdd29371aa5 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 9 Aug 2024 14:13:56 -0400 Subject: [PATCH 05/15] remove tag not found errors --- src/serializers/type_serializers/union.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index c0caea5a4..920c2d997 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -260,12 +260,9 @@ impl TypeSerializer for TaggedUnionSerializer { false => return Err(err), }, } - } else { - return Err(self.tag_not_found()); } } - // Fallback processing let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); if let Ok(s) = basic_union_ser { return s.to_python(value, include, exclude, extra); @@ -293,12 +290,9 @@ impl TypeSerializer for TaggedUnionSerializer { } } } - } else { - return Err(self.tag_not_found()); } } - // Fallback processing let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); if let Ok(s) = basic_union_ser { return s.json_key(key, extra); @@ -337,12 +331,9 @@ impl TypeSerializer for TaggedUnionSerializer { } } } - } else { - return Err(py_err_se_err(self.tag_not_found())); } } - // Fallback processing let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); if let Ok(s) = basic_union_ser { return s.serde_serialize(value, serializer, include, exclude, extra); @@ -374,8 +365,4 @@ impl TaggedUnionSerializer { }), } } - - fn tag_not_found(&self) -> PyErr { - PydanticSerializationUnexpectedValue::new_err(Some("Tag not found in tagged union for value: {:?}".to_string())) - } } From 7b189aa54a8d244a6a736a41ffbbe8c8c8eb0676 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 9 Aug 2024 14:15:40 -0400 Subject: [PATCH 06/15] another simplification --- src/serializers/type_serializers/union.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 920c2d997..e871610fd 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -243,9 +243,7 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let discriminator_value = self.get_discriminator_value(value); - - if let Some(tag) = discriminator_value { + if let Some(tag) = self.get_discriminator_value(value) { let tag_str = tag.to_string(); if let Some(serializer) = self.lookup.get(&tag_str) { match serializer.to_python(value, include, exclude, &new_extra) { @@ -276,9 +274,7 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let discriminator_value = self.get_discriminator_value(key); - - if let Some(tag) = discriminator_value { + if let Some(tag) = self.get_discriminator_value(key) { let tag_str = tag.to_string(); if let Some(serializer) = self.lookup.get(&tag_str) { match serializer.json_key(key, &new_extra) { @@ -314,9 +310,7 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let discriminator_value = self.get_discriminator_value(value); - - if let Some(tag) = discriminator_value { + if let Some(tag) = self.get_discriminator_value(value) { let tag_str = tag.to_string(); if let Some(selected_serializer) = self.lookup.get(&tag_str) { match selected_serializer.to_python(value, include, exclude, &new_extra) { From 37b4dbbc486c4d7b4d5659591336628bf83770a1 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 9 Aug 2024 14:29:48 -0400 Subject: [PATCH 07/15] adding tagged union test --- tests/serializers/test_union.py | 49 +++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index ee5ed3fc4..a1c5ec08b 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -626,3 +626,52 @@ def test_union_serializer_picks_exact_type_over_subclass_json( ) assert s.to_python(input_value, mode='json') == expected_value assert s.to_json(input_value) == json.dumps(expected_value).encode() + + +def test_tagged_union() -> None: + @dataclasses.dataclass + class ModelA: + field: int + tag: Literal['a'] = 'a' + + @dataclasses.dataclass + class ModelB: + field: int + tag: Literal['b'] = 'b' + + s = SchemaSerializer( + core_schema.tagged_union_schema( + choices={ + 'a': core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', + [ + core_schema.dataclass_field(name='field', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='tag', schema=core_schema.literal_schema(['a'])), + ], + ), + ['field', 'tag'], + ), + 'b': core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', + [ + core_schema.dataclass_field(name='field', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='tag', schema=core_schema.literal_schema(['b'])), + ], + ), + ['field', 'tag'], + ), + }, + discriminator='tag', + ) + ) + + assert 'TaggedUnionSerializer' in repr(s) + + model_a = ModelA(field=1) + model_b = ModelB(field=1) + assert s.to_python(model_a) == {'field': 1, 'tag': 'a'} + assert s.to_python(model_b) == {'field': 1, 'tag': 'b'} From e3aceeb2fa184d20ac58fd2942fb43fe9e1ef0bb Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 12 Aug 2024 12:43:47 -0400 Subject: [PATCH 08/15] fixing merge conflict --- src/serializers/type_serializers/union.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 070b3b720..a54c14246 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -8,6 +8,8 @@ use std::borrow::Cow; use crate::build_tools::py_schema_err; use crate::common::discriminator::Discriminator; use crate::definitions::DefinitionsBuilder; +use crate::lookup_key::LookupKey; +use crate::serializers::type_serializers::py_err_se_err; use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY}; use crate::PydanticSerializationUnexpectedValue; From a04e297690f6f0834be6e840506a772e1599961d Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 12 Aug 2024 12:47:14 -0400 Subject: [PATCH 09/15] moving union const --- src/common/mod.rs | 2 +- src/common/{discriminator.rs => union.rs} | 2 ++ src/serializers/type_serializers/union.rs | 10 +++++----- src/tools.rs | 2 -- src/validators/union.rs | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) rename src/common/{discriminator.rs => union.rs} (95%) diff --git a/src/common/mod.rs b/src/common/mod.rs index dd75e9167..11f2e1ece 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1 +1 @@ -pub(crate) mod discriminator; +pub(crate) mod union; diff --git a/src/common/discriminator.rs b/src/common/union.rs similarity index 95% rename from src/common/discriminator.rs rename to src/common/union.rs index 8995fdcfc..17fe9ad90 100644 --- a/src/common/discriminator.rs +++ b/src/common/union.rs @@ -39,3 +39,5 @@ impl PyGcTraverse for Discriminator { Ok(()) } } + +pub(crate) const SMALL_UNION_THRESHOLD: usize = 4; diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index a54c14246..2715c5fea 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -6,11 +6,11 @@ use smallvec::SmallVec; use std::borrow::Cow; use crate::build_tools::py_schema_err; -use crate::common::discriminator::Discriminator; +use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; use crate::lookup_key::LookupKey; use crate::serializers::type_serializers::py_err_se_err; -use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY}; +use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; use super::{ @@ -83,7 +83,7 @@ impl TypeSerializer for UnionSerializer { // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -118,7 +118,7 @@ impl TypeSerializer for UnionSerializer { fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); for comb_serializer in &self.choices { match comb_serializer.json_key(key, &new_extra) { @@ -161,7 +161,7 @@ impl TypeSerializer for UnionSerializer { let py = value.py(); let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { diff --git a/src/tools.rs b/src/tools.rs index 121ae3880..adf64c91a 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -146,5 +146,3 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach pystring_fast_new(py, s, ascii_only) } } - -pub(crate) const UNION_ERR_SMALLVEC_CAPACITY: usize = 4; diff --git a/src/validators/union.rs b/src/validators/union.rs index 70057fad9..747f6a0cf 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -9,10 +9,10 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; -use crate::common::discriminator::Discriminator; +use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, Input, ValidatedDict}; -use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY}; +use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; @@ -249,7 +249,7 @@ struct ChoiceLineErrors<'a> { enum MaybeErrors<'a> { Custom(&'a CustomError), - Errors(SmallVec<[ChoiceLineErrors<'a>; UNION_ERR_SMALLVEC_CAPACITY]>), + Errors(SmallVec<[ChoiceLineErrors<'a>; SMALL_UNION_THRESHOLD]>), } impl<'a> MaybeErrors<'a> { From bf4540943cb76352311bb463910542cb864c0714 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 12 Aug 2024 20:26:04 -0400 Subject: [PATCH 10/15] rename map --- src/serializers/type_serializers/union.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 2715c5fea..20533ce1c 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -221,11 +221,11 @@ impl BuildSerializer for TaggedUnionSerializer { let py = schema.py(); let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?; - let choice_list: Bound = schema.get_as_req(intern!(py, "choices"))?; - let mut lookup: HashMap = HashMap::with_capacity(choice_list.len()); - let mut choices: Vec = Vec::with_capacity(choice_list.len()); + let choices_map: Bound = schema.get_as_req(intern!(py, "choices"))?; + let mut lookup: HashMap = HashMap::with_capacity(choices_map.len()); + let mut choices: Vec = Vec::with_capacity(choices_map.len()); - for (choice_key, choice_schema) in choice_list { + for (choice_key, choice_schema) in choices_map { let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions).unwrap(); choices.push(serializer.clone()); lookup.insert(choice_key.to_string(), serializer); From 27faf1a8e536092636026791056731c2c0cba38b Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 13 Aug 2024 12:14:27 -0400 Subject: [PATCH 11/15] adding retry w lax check --- src/serializers/type_serializers/union.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 6719b826b..b0ea3f6fc 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -359,6 +359,10 @@ impl TypeSerializer for TaggedUnionSerializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) + } } impl TaggedUnionSerializer { From dacb9ff1c101e57fb18f7f65f410adb5defcb374 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 14 Aug 2024 21:27:29 -0400 Subject: [PATCH 12/15] benchmark --- src/serializers/type_serializers/union.rs | 116 ++++++++++++---------- 1 file changed, 62 insertions(+), 54 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index b0ea3f6fc..f89b7d4e9 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -80,39 +80,7 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - // try the serializers in left to right order with error_on fallback=true - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), - }, - } - } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), - }, - } - } - } - - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - - extra.warnings.on_fallback_py(self.get_name(), value, extra)?; - infer_to_python(value, include, exclude, extra) + to_python(value, include, exclude, extra, &self.choices, self.get_name()) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -202,10 +170,55 @@ impl TypeSerializer for UnionSerializer { } } -#[derive(Debug, Clone)] +fn to_python( + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, +) -> PyResult { + // try the serializers in left to right order with error_on fallback=true + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + + for comb_serializer in choices.clone() { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(value.py()) { + true => (), + false => errors.push(err), + }, + } + } + + let retry_with_lax_check = choices.clone().into_iter().any(CombinedSerializer::retry_with_lax_check); + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(value.py()) { + true => (), + false => errors.push(err), + }, + } + } + } + + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + + extra.warnings.on_fallback_py(name, value, extra)?; + infer_to_python(value, include, exclude, extra) +} + +#[derive(Debug)] pub struct TaggedUnionSerializer { discriminator: Discriminator, - lookup: HashMap, + lookup: HashMap, choices: Vec, name: String, } @@ -221,14 +234,15 @@ impl BuildSerializer for TaggedUnionSerializer { let py = schema.py(); let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?; + // TODO: guarantee at least 1 choice let choices_map: Bound = schema.get_as_req(intern!(py, "choices"))?; - let mut lookup: HashMap = HashMap::with_capacity(choices_map.len()); - let mut choices: Vec = Vec::with_capacity(choices_map.len()); + let mut lookup = HashMap::with_capacity(choices_map.len()); + let mut choices = Vec::with_capacity(choices_map.len()); - for (choice_key, choice_schema) in choices_map { - let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions).unwrap(); - choices.push(serializer.clone()); - lookup.insert(choice_key.to_string(), serializer); + for (idx, (choice_key, choice_schema)) in choices_map.into_iter().enumerate() { + let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions)?; + choices.push(serializer); + lookup.insert(choice_key.to_string(), idx); } let descr = choices @@ -265,13 +279,13 @@ impl TypeSerializer for TaggedUnionSerializer { if let Some(tag) = self.get_discriminator_value(value) { let tag_str = tag.to_string(); if let Some(serializer) = self.lookup.get(&tag_str) { - match serializer.to_python(value, include, exclude, &new_extra) { + match self.choices[*serializer].to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(py) { true => { if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; - return serializer.to_python(value, include, exclude, &new_extra); + return self.choices[*serializer].to_python(value, include, exclude, &new_extra); } } false => return Err(err), @@ -280,13 +294,7 @@ impl TypeSerializer for TaggedUnionSerializer { } } - let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); - if let Ok(s) = basic_union_ser { - return s.to_python(value, include, exclude, extra); - } - - extra.warnings.on_fallback_py(self.get_name(), value, extra)?; - infer_to_python(value, include, exclude, extra) + to_python(value, include, exclude, extra, &self.choices, self.get_name()) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -296,12 +304,12 @@ impl TypeSerializer for TaggedUnionSerializer { if let Some(tag) = self.get_discriminator_value(key) { let tag_str = tag.to_string(); if let Some(serializer) = self.lookup.get(&tag_str) { - match serializer.json_key(key, &new_extra) { + match self.choices[*serializer].json_key(key, &new_extra) { Ok(v) => return Ok(v), Err(_) => { if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; - return serializer.json_key(key, &new_extra); + return self.choices[*serializer].json_key(key, &new_extra); } } } @@ -332,12 +340,12 @@ impl TypeSerializer for TaggedUnionSerializer { if let Some(tag) = self.get_discriminator_value(value) { let tag_str = tag.to_string(); if let Some(selected_serializer) = self.lookup.get(&tag_str) { - match selected_serializer.to_python(value, include, exclude, &new_extra) { + match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) { Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), Err(_) => { if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; - match selected_serializer.to_python(value, include, exclude, &new_extra) { + match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) { Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), Err(err) => return Err(py_err_se_err(err)), } From 7f2b1dac03d5eb283739c149ada5f4386a42b748 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 14 Aug 2024 21:47:25 -0400 Subject: [PATCH 13/15] missing lifteime and pygctraverse stuff --- src/serializers/type_serializers/union.rs | 261 +++++++++++++--------- 1 file changed, 153 insertions(+), 108 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index f89b7d4e9..a1d2daa2b 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -72,134 +72,124 @@ impl UnionSerializer { impl_py_gc_traverse!(UnionSerializer { choices }); -impl TypeSerializer for UnionSerializer { - fn to_python( - &self, - value: &Bound<'_, PyAny>, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - ) -> PyResult { - to_python(value, include, exclude, extra, &self.choices, self.get_name()) - } +fn to_python( + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, + retry_with_lax_check: bool, +) -> PyResult { + // try the serializers in left to right order with error_on fallback=true + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + for comb_serializer in choices.clone() { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(value.py()) { + true => (), + false => errors.push(err), + }, + } + } - for comb_serializer in &self.choices { - match comb_serializer.json_key(key, &new_extra) { + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(key.py()) { + Err(err) => match err.is_instance_of::(value.py()) { true => (), false => errors.push(err), }, } } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(key.py()) { - true => (), - false => errors.push(err), - }, - } - } - } - - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } + } - extra.warnings.on_fallback_py(self.get_name(), key, extra)?; - infer_json_key(key, extra) + for err in &errors { + extra.warnings.custom_warning(err.to_string()); } - fn serde_serialize( - &self, - value: &Bound<'_, PyAny>, - serializer: S, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - ) -> Result { - let py = value.py(); - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + extra.warnings.on_fallback_py(name, value, extra)?; + infer_to_python(value, include, exclude, extra) +} - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(value.py()) { +fn json_key( + key: &Bound<'_, PyAny>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, + retry_with_lax_check: bool, +) -> PyResult> { + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + + for comb_serializer in choices.clone() { + match comb_serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(key.py()) { + true => (), + false => errors.push(err), + }, + } + } + + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { + match comb_serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(key.py()) { true => (), false => errors.push(err), }, } } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), - }, - } - } - } - - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - - extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; - infer_serialize(value, serializer, include, exclude, extra) } - fn get_name(&self) -> &str { - &self.name + for err in &errors { + extra.warnings.custom_warning(err.to_string()); } - fn retry_with_lax_check(&self) -> bool { - self.choices.iter().any(CombinedSerializer::retry_with_lax_check) - } + extra.warnings.on_fallback_py(name, key, extra)?; + infer_json_key(key, extra) } -fn to_python( +fn serde_serialize( value: &Bound<'_, PyAny>, + serializer: S, include: Option<&Bound<'_, PyAny>>, exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, choices: &[CombinedSerializer], name: &str, -) -> PyResult { - // try the serializers in left to right order with error_on fallback=true + retry_with_lax_check: bool, +) -> Result { + let py = value.py(); let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); for comb_serializer in choices.clone() { match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { true => (), false => errors.push(err), }, } } - let retry_with_lax_check = choices.clone().into_iter().any(CombinedSerializer::retry_with_lax_check); if retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { true => (), false => errors.push(err), }, @@ -211,8 +201,60 @@ fn to_python( extra.warnings.custom_warning(err.to_string()); } - extra.warnings.on_fallback_py(name, value, extra)?; - infer_to_python(value, include, exclude, extra) + extra.warnings.on_fallback_ser::(name, value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) +} + +impl TypeSerializer for UnionSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + to_python( + value, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + json_key(key, extra, &self.choices, self.get_name(), self.retry_with_lax_check()) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + serde_serialize( + value, + serializer, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) + } + + fn get_name(&self) -> &str { + &self.name + } + + fn retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) + } } #[derive(Debug)] @@ -294,7 +336,15 @@ impl TypeSerializer for TaggedUnionSerializer { } } - to_python(value, include, exclude, extra, &self.choices, self.get_name()) + to_python( + value, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -316,13 +366,7 @@ impl TypeSerializer for TaggedUnionSerializer { } } - let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); - if let Ok(s) = basic_union_ser { - return s.json_key(key, extra); - } - - extra.warnings.on_fallback_py(self.get_name(), key, extra)?; - infer_json_key(key, extra) + json_key(key, extra, &self.choices, self.get_name(), self.retry_with_lax_check()) } fn serde_serialize( @@ -355,13 +399,16 @@ impl TypeSerializer for TaggedUnionSerializer { } } - let basic_union_ser = UnionSerializer::from_choices(self.choices.clone()); - if let Ok(s) = basic_union_ser { - return s.serde_serialize(value, serializer, include, exclude, extra); - } - - extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; - infer_serialize(value, serializer, include, exclude, extra) + serde_serialize( + value, + serializer, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) } fn get_name(&self) -> &str { @@ -376,18 +423,16 @@ impl TypeSerializer for TaggedUnionSerializer { impl TaggedUnionSerializer { fn get_discriminator_value(&self, value: &Bound<'_, PyAny>) -> Option> { let py = value.py(); - match &self.discriminator { + let discriminator_value = match &self.discriminator { Discriminator::LookupKey(lookup_key) => match lookup_key { LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok().map(|obj| obj.to_object(py)), _ => None, }, - Discriminator::Function(func) => func.call1(py, (value,)).ok().or_else(|| { - // Try converting object to a dict, might be more compatible with poorly defined callable discriminator - value - .call_method0(intern!(py, "dict")) - .and_then(|v| func.call1(py, (v.to_object(py),))) - .ok() - }), + Discriminator::Function(func) => func.call1(py, (value,)).ok(), + }; + if discriminator_value.is_none() { + // warn if the discriminator value is not found } + return discriminator_value; } } From c32560952c5422c727c031d06e56f436461b1c45 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 15 Aug 2024 10:20:20 -0400 Subject: [PATCH 14/15] final cleanup, thanks @davidhewitt --- src/serializers/type_serializers/union.rs | 82 ++++++++++++++--------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index a1d2daa2b..35dc083e8 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -86,7 +86,7 @@ fn to_python( new_extra.check = SerCheck::Strict; let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - for comb_serializer in choices.clone() { + for comb_serializer in choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(value.py()) { @@ -117,18 +117,18 @@ fn to_python( infer_to_python(value, include, exclude, extra) } -fn json_key( - key: &Bound<'_, PyAny>, +fn json_key<'a>( + key: &'a Bound<'_, PyAny>, extra: &Extra, choices: &[CombinedSerializer], name: &str, retry_with_lax_check: bool, -) -> PyResult> { +) -> PyResult> { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - for comb_serializer in choices.clone() { + for comb_serializer in choices { match comb_serializer.json_key(key, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(key.py()) { @@ -159,6 +159,7 @@ fn json_key( infer_json_key(key, extra) } +#[allow(clippy::too_many_arguments)] fn serde_serialize( value: &Bound<'_, PyAny>, serializer: S, @@ -174,7 +175,7 @@ fn serde_serialize( new_extra.check = SerCheck::Strict; let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - for comb_serializer in choices.clone() { + for comb_serializer in choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), Err(err) => match err.is_instance_of::(py) { @@ -303,7 +304,7 @@ impl BuildSerializer for TaggedUnionSerializer { } } -impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, lookup }); +impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, choices }); impl TypeSerializer for TaggedUnionSerializer { fn to_python( @@ -318,16 +319,18 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - if let Some(tag) = self.get_discriminator_value(value) { + if let Some(tag) = self.get_discriminator_value(value, extra) { let tag_str = tag.to_string(); - if let Some(serializer) = self.lookup.get(&tag_str) { - match self.choices[*serializer].to_python(value, include, exclude, &new_extra) { + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let serializer = &self.choices[serializer_index]; + + match serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(py) { true => { if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; - return self.choices[*serializer].to_python(value, include, exclude, &new_extra); + return serializer.to_python(value, include, exclude, &new_extra); } } false => return Err(err), @@ -348,20 +351,26 @@ impl TypeSerializer for TaggedUnionSerializer { } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + let py = key.py(); let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - if let Some(tag) = self.get_discriminator_value(key) { + if let Some(tag) = self.get_discriminator_value(key, extra) { let tag_str = tag.to_string(); - if let Some(serializer) = self.lookup.get(&tag_str) { - match self.choices[*serializer].json_key(key, &new_extra) { + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let serializer = &self.choices[serializer_index]; + + match serializer.json_key(key, &new_extra) { Ok(v) => return Ok(v), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - return self.choices[*serializer].json_key(key, &new_extra); + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.json_key(key, &new_extra); + } } - } + false => return Err(err), + }, } } } @@ -381,20 +390,25 @@ impl TypeSerializer for TaggedUnionSerializer { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - if let Some(tag) = self.get_discriminator_value(value) { + if let Some(tag) = self.get_discriminator_value(value, extra) { let tag_str = tag.to_string(); - if let Some(selected_serializer) = self.lookup.get(&tag_str) { - match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) { + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let selected_serializer = &self.choices[serializer_index]; + + match selected_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => return Err(py_err_se_err(err)), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => return Err(py_err_se_err(err)), + } } } - } + false => return Err(py_err_se_err(err)), + }, } } } @@ -421,7 +435,7 @@ impl TypeSerializer for TaggedUnionSerializer { } impl TaggedUnionSerializer { - fn get_discriminator_value(&self, value: &Bound<'_, PyAny>) -> Option> { + fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option> { let py = value.py(); let discriminator_value = match &self.discriminator { Discriminator::LookupKey(lookup_key) => match lookup_key { @@ -431,8 +445,12 @@ impl TaggedUnionSerializer { Discriminator::Function(func) => func.call1(py, (value,)).ok(), }; if discriminator_value.is_none() { - // warn if the discriminator value is not found + extra.warnings.custom_warning( + format!( + "Failed to get discriminator value for tagged union serialization for {value} - defaulting to left to right union serialization." + ) + ); } - return discriminator_value; + discriminator_value } } From f1236fffe25a33535ce3078cd7eb425d020e9eda Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Thu, 15 Aug 2024 10:59:03 -0400 Subject: [PATCH 15/15] limit value_str --- src/serializers/type_serializers/union.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 35dc083e8..23aa874fa 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -8,9 +8,10 @@ use std::borrow::Cow; use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; +use crate::errors::write_truncated_to_50_bytes; use crate::lookup_key::LookupKey; use crate::serializers::type_serializers::py_err_se_err; -use crate::tools::SchemaDict; +use crate::tools::{safe_repr, SchemaDict}; use crate::PydanticSerializationUnexpectedValue; use super::{ @@ -445,9 +446,15 @@ impl TaggedUnionSerializer { Discriminator::Function(func) => func.call1(py, (value,)).ok(), }; if discriminator_value.is_none() { + let input_str = safe_repr(value); + let mut value_str = String::with_capacity(100); + value_str.push_str("with value `"); + write_truncated_to_50_bytes(&mut value_str, input_str.to_cow()).expect("Writing to a `String` failed"); + value_str.push('`'); + extra.warnings.custom_warning( format!( - "Failed to get discriminator value for tagged union serialization for {value} - defaulting to left to right union serialization." + "Failed to get discriminator value for tagged union serialization {value_str} - defaulting to left to right union serialization." ) ); }