diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index 791de9d92..9884c79ed 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -124,7 +124,7 @@ class ErrorTypeInfo(_TypedDict): """Example of context values.""" -class MultiHostHost(_TypedDict): +class MultiHostHost(_TypedDict, total=False): """ A host part of a multi-host URL. """ diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index d696ac407..fabd92b3a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -583,7 +583,7 @@ class Url(SupportsAllComparisons): scheme: str, username: str | None = None, password: str | None = None, - host: str, + host: str | None = None, port: int | None = None, path: str | None = None, query: str | None = None, @@ -596,7 +596,7 @@ class Url(SupportsAllComparisons): scheme: The scheme part of the URL. username: The username part of the URL, or omit for no username. password: The password part of the URL, or omit for no password. - host: The host part of the URL. + host: The host part of the URL, or omit for no host. port: The port part of the URL, or omit for no port. path: The path part of the URL, or omit for no path. query: The query part of the URL, or omit for no query. diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 5a8646fca..23e384af1 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3655,6 +3655,7 @@ class MyModel: class UrlSchema(TypedDict, total=False): type: Required[Literal['url']] + cls: Type[Any] max_length: int allowed_schemes: List[str] host_required: bool # default False @@ -3669,6 +3670,7 @@ class UrlSchema(TypedDict, total=False): def url_schema( *, + cls: Type[Any] | None = None, max_length: int | None = None, allowed_schemes: list[str] | None = None, host_required: bool | None = None, @@ -3693,6 +3695,7 @@ def url_schema( ``` Args: + cls: The class to use for the URL build (a subclass of `pydantic_core.Url`) max_length: The maximum length of the URL allowed_schemes: The allowed URL schemes host_required: Whether the URL must have a host @@ -3706,6 +3709,7 @@ def url_schema( """ return _dict_not_none( type='url', + cls=cls, max_length=max_length, allowed_schemes=allowed_schemes, host_required=host_required, @@ -3721,6 +3725,7 @@ def url_schema( class MultiHostUrlSchema(TypedDict, total=False): type: Required[Literal['multi-host-url']] + cls: Type[Any] max_length: int allowed_schemes: List[str] host_required: bool # default False @@ -3735,6 +3740,7 @@ class MultiHostUrlSchema(TypedDict, total=False): def multi_host_url_schema( *, + cls: Type[Any] | None = None, max_length: int | None = None, allowed_schemes: list[str] | None = None, host_required: bool | None = None, @@ -3759,6 +3765,7 @@ def multi_host_url_schema( ``` Args: + cls: The class to use for the URL build (a subclass of `pydantic_core.MultiHostUrl`) max_length: The maximum length of the URL allowed_schemes: The allowed URL schemes host_required: Whether the URL must have a host @@ -3772,6 +3779,7 @@ def multi_host_url_schema( """ return _dict_not_none( type='multi-host-url', + cls=cls, max_length=max_length, allowed_schemes=allowed_schemes, host_required=host_required, diff --git a/src/url.rs b/src/url.rs index 881347f25..faaab9c63 100644 --- a/src/url.rs +++ b/src/url.rs @@ -156,12 +156,12 @@ impl PyUrl { } #[classmethod] - #[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))] + #[pyo3(signature=(*, scheme, host=None, username=None, password=None, port=None, path=None, query=None, fragment=None))] #[allow(clippy::too_many_arguments)] pub fn build<'py>( cls: &Bound<'py, PyType>, scheme: &str, - host: &str, + host: Option<&str>, username: Option<&str>, password: Option<&str>, port: Option, @@ -172,7 +172,7 @@ impl PyUrl { let url_host = UrlHostParts { username: username.map(Into::into), password: password.map(Into::into), - host: Some(host.into()), + host: host.map(Into::into), port, }; let mut url = format!("{scheme}://{url_host}"); @@ -423,6 +423,7 @@ impl PyMultiHostUrl { } } +#[cfg_attr(debug_assertions, derive(Debug))] pub struct UrlHostParts { username: Option, password: Option, @@ -440,11 +441,12 @@ impl FromPyObject<'_> for UrlHostParts { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { let py = ob.py(); let dict = ob.downcast::()?; + Ok(UrlHostParts { - username: dict.get_as(intern!(py, "username"))?, - password: dict.get_as(intern!(py, "password"))?, - host: dict.get_as(intern!(py, "host"))?, - port: dict.get_as(intern!(py, "port"))?, + username: dict.get_as::>(intern!(py, "username"))?.flatten(), + password: dict.get_as::>(intern!(py, "password"))?.flatten(), + host: dict.get_as::>(intern!(py, "host"))?.flatten(), + port: dict.get_as::>(intern!(py, "port"))?.flatten(), }) } } diff --git a/src/validators/url.rs b/src/validators/url.rs index 46bab20c1..92c482b46 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -4,7 +4,7 @@ use std::str::Chars; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; +use pyo3::types::{PyDict, PyList, PyType}; use ahash::AHashSet; use url::{ParseError, SyntaxViolation, Url}; @@ -26,6 +26,7 @@ type AllowedSchemas = Option<(AHashSet, String)>; #[derive(Debug, Clone)] pub struct UrlValidator { strict: bool, + cls: Option>, max_length: Option, allowed_schemes: AllowedSchemas, host_required: bool, @@ -47,6 +48,7 @@ impl BuildValidator for UrlValidator { Ok(Self { strict: is_strict(schema, config)?, + cls: schema.get_as(intern!(schema.py(), "cls"))?, max_length: schema.get_as(intern!(schema.py(), "max_length"))?, host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false), default_host: schema.get_as(intern!(schema.py(), "default_host"))?, @@ -59,7 +61,7 @@ impl BuildValidator for UrlValidator { } } -impl_py_gc_traverse!(UrlValidator {}); +impl_py_gc_traverse!(UrlValidator { cls }); impl Validator for UrlValidator { fn validate<'py>( @@ -93,7 +95,31 @@ impl Validator for UrlValidator { Ok(()) => { // Lax rather than strict to preserve V2.4 semantic that str wins over url in union state.floor_exactness(Exactness::Lax); - Ok(either_url.into_py(py)) + + if let Some(url_subclass) = &self.cls { + // TODO: we do an extra build for a subclass here, we should avoid this + // in v2.11 for perf reasons, but this is a worthwhile patch for now + // given that we want isinstance to work properly for subclasses of Url + let py_url = match either_url { + EitherUrl::Py(py_url) => py_url.get().clone(), + EitherUrl::Rust(rust_url) => PyUrl::new(rust_url), + }; + + let py_url = PyUrl::build( + url_subclass.bind(py), + py_url.scheme(), + py_url.host(), + py_url.username(), + py_url.password(), + py_url.port(), + py_url.path().filter(|path| *path != "/"), + py_url.query(), + py_url.fragment(), + )?; + Ok(py_url.into_py(py)) + } else { + Ok(either_url.into_py(py)) + } } Err(error_type) => Err(ValError::new(error_type, input)), } @@ -186,6 +212,7 @@ impl CopyFromPyUrl for EitherUrl<'_> { #[derive(Debug, Clone)] pub struct MultiHostUrlValidator { strict: bool, + cls: Option>, max_length: Option, allowed_schemes: AllowedSchemas, host_required: bool, @@ -213,6 +240,7 @@ impl BuildValidator for MultiHostUrlValidator { } Ok(Self { strict: is_strict(schema, config)?, + cls: schema.get_as(intern!(schema.py(), "cls"))?, max_length: schema.get_as(intern!(schema.py(), "max_length"))?, allowed_schemes, host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false), @@ -225,7 +253,7 @@ impl BuildValidator for MultiHostUrlValidator { } } -impl_py_gc_traverse!(MultiHostUrlValidator {}); +impl_py_gc_traverse!(MultiHostUrlValidator { cls }); impl Validator for MultiHostUrlValidator { fn validate<'py>( @@ -258,7 +286,38 @@ impl Validator for MultiHostUrlValidator { Ok(()) => { // Lax rather than strict to preserve V2.4 semantic that str wins over url in union state.floor_exactness(Exactness::Lax); - Ok(multi_url.into_py(py)) + + if let Some(url_subclass) = &self.cls { + // TODO: we do an extra build for a subclass here, we should avoid this + // in v2.11 for perf reasons, but this is a worthwhile patch for now + // given that we want isinstance to work properly for subclasses of Url + let py_url = match multi_url { + EitherMultiHostUrl::Py(py_url) => py_url.get().clone(), + EitherMultiHostUrl::Rust(rust_url) => rust_url, + }; + + let hosts = py_url + .hosts(py)? + .into_iter() + .map(|host| host.extract().expect("host should be a valid UrlHostParts")) + .collect(); + + let py_url = PyMultiHostUrl::build( + url_subclass.bind(py), + py_url.scheme(), + Some(hosts), + py_url.path().filter(|path| *path != "/"), + py_url.query(), + py_url.fragment(), + None, + None, + None, + None, + )?; + Ok(py_url.into_py(py)) + } else { + Ok(multi_url.into_py(py)) + } } Err(error_type) => Err(ValError::new(error_type, input)), } diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 59489dd00..112046090 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -1305,3 +1305,19 @@ def test_url_build() -> None: ) assert url == Url('postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test') assert str(url) == 'postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test' + + +def test_url_subclass() -> None: + class UrlSubclass(Url): + pass + + validator = SchemaValidator(core_schema.url_schema(cls=UrlSubclass)) + assert isinstance(validator.validate_python('http://example.com'), UrlSubclass) + + +def test_multi_host_url_subclass() -> None: + class MultiHostUrlSubclass(MultiHostUrl): + pass + + validator = SchemaValidator(core_schema.multi_host_url_schema(cls=MultiHostUrlSubclass)) + assert isinstance(validator.validate_python('http://example.com'), MultiHostUrlSubclass)