Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tagged union serializer 🚀 #1397

Merged
merged 17 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod union;
43 changes: 43 additions & 0 deletions src/common/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use pyo3::prelude::*;
use pyo3::{PyTraverseError, PyVisit};

use crate::lookup_key::LookupKey;
use crate::py_gc::PyGcTraverse;

#[derive(Debug, Clone)]
pub enum Discriminator {
/// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases
LookupKey(LookupKey),
/// call a function to find the tag to use
Function(PyObject),
}

impl Discriminator {
pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult<Self> {
if raw.is_callable() {
return Ok(Self::Function(raw.to_object(py)));
}

let lookup_key = LookupKey::from_py(py, raw, None)?;
Ok(Self::LookupKey(lookup_key))
}

pub fn to_string_py(&self, py: Python) -> PyResult<String> {
match self {
Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)),
Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()),
}
}
}

impl PyGcTraverse for Discriminator {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self {
Self::Function(obj) => visit.call(obj)?,
Self::LookupKey(_) => {}
}
Ok(())
}
}

pub(crate) const SMALL_UNION_THRESHOLD: usize = 4;
4 changes: 0 additions & 4 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ pub trait ValidatedDict<'py> {
where
Self: 'a;
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>>;
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>>;
// FIXME this is a bit of a leaky abstraction
fn is_py_get_attr(&self) -> bool {
false
Expand Down Expand Up @@ -280,9 +279,6 @@ impl<'py> ValidatedDict<'py> for Never {
fn get_item<'k>(&self, _key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
unreachable!()
}
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
unreachable!()
}
fn iterate<'a, R>(
&'a self,
_consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
4 changes: 0 additions & 4 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,6 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> {
key.json_get(self)
}

fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
None
}

fn iterate<'a, R>(
&'a self,
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
7 changes: 0 additions & 7 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,13 +775,6 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> {
matches!(self, Self::GetAttr(..))
}

fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
match self {
Self::Dict(dict) => Some(dict),
_ => None,
}
}

fn iterate<'a, R>(
&'a self,
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
3 changes: 0 additions & 3 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,6 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> {
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
key.py_get_string_mapping_item(&self.0)
}
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
None
}
fn iterate<'a, R>(
&'a self,
consumer: impl super::ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod py_gc;

mod argument_markers;
mod build_tools;
mod common;
mod definitions;
mod errors;
mod input;
Expand Down
3 changes: 2 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ combined_serializer! {
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
find_only: {
super::type_serializers::union::TaggedUnionBuilder;
super::type_serializers::other::ChainBuilder;
super::type_serializers::other::CustomErrorBuilder;
super::type_serializers::other::CallBuilder;
Expand Down Expand Up @@ -138,6 +137,7 @@ combined_serializer! {
Json: super::type_serializers::json::JsonSerializer;
JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer;
Union: super::type_serializers::union::UnionSerializer;
TaggedUnion: super::type_serializers::union::TaggedUnionSerializer;
Literal: super::type_serializers::literal::LiteralSerializer;
Enum: super::type_serializers::enum_::EnumSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Expand Down Expand Up @@ -246,6 +246,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TaggedUnion(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
Expand Down
Loading
Loading