Skip to content

Commit

Permalink
test: Fix tests (#20745)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 16, 2025
1 parent ab8ff43 commit f56affc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 26 deletions.
13 changes: 12 additions & 1 deletion crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use regex::Regex;
#[cfg(feature = "http")]
use reqwest::header::HeaderMap;
#[cfg(feature = "serde")]
use serde::Deserializer;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "cloud")]
use url::Url;
Expand Down Expand Up @@ -78,10 +80,19 @@ pub struct CloudOptions {
pub file_cache_ttl: u64,
pub(crate) config: Option<CloudConfig>,
#[cfg(feature = "cloud")]
#[cfg_attr(feature = "serde", serde(skip))] // skipped for polars-cloud
#[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_or_default"))]
pub(crate) credential_provider: Option<PlCredentialProvider>,
}

#[cfg(all(feature = "serde", feature = "cloud"))]
fn deserialize_or_default<'de, D>(deserializer: D) -> Result<Option<PlCredentialProvider>, D::Error>
where
D: Deserializer<'de>,
{
type T = Option<PlCredentialProvider>;
T::deserialize(deserializer).or_else(|_| Ok(Default::default()))
}

impl Default for CloudOptions {
fn default() -> Self {
Self::default_static_ref().clone()
Expand Down
25 changes: 0 additions & 25 deletions py-polars/tests/unit/io/cloud/test_credential_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import sys
from typing import Any

import pytest
Expand Down Expand Up @@ -70,30 +69,6 @@ def __call__(self) -> pl.CredentialProviderFunctionReturn:
lf.collect()


def test_scan_credential_provider_serialization_pyversion() -> None:
lf = pl.scan_parquet(
"s3://bucket/path", credential_provider=pl.CredentialProviderAWS()
)

serialized = lf.serialize()
serialized = bytearray(serialized)

# We can't monkeypatch sys.python_version so we just mutate the output
# instead.

v = b"PLPYFN"
i = serialized.index(v) + len(v)
a, b = serialized[i:][:2]
serialized_pyver = (a, b)
assert serialized_pyver == (sys.version_info.minor, sys.version_info.micro)
# Note: These are loaded as u8's
serialized[i] = 255
serialized[i + 1] = 254

with pytest.raises(ComputeError, match=r"python version.*(3, 255, 254).*differs.*"):
lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))


def test_credential_provider_skips_config_autoload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
Expand Down

0 comments on commit f56affc

Please sign in to comment.