Skip to content

Commit

Permalink
readonly_session("main", as_of=...) to open a branch at a timestamp (
Browse files Browse the repository at this point in the history
  • Loading branch information
paraseba authored Feb 23, 2025
1 parent 474e004 commit a0be3c8
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 15 deletions.
1 change: 1 addition & 0 deletions icechunk-python/python/icechunk/_icechunk_python.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ class PyRepository:
*,
tag: str | None = None,
snapshot_id: str | None = None,
as_of: datetime.datetime | None = None,
) -> PySession: ...
def writable_session(self, branch: str) -> PySession: ...
def expire_snapshots(
Expand Down
6 changes: 5 additions & 1 deletion icechunk-python/python/icechunk/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def readonly_session(
*,
tag: str | None = None,
snapshot_id: str | None = None,
as_of: datetime.datetime | None = None,
) -> Session:
"""
Create a read-only session.
Expand All @@ -453,6 +454,9 @@ def readonly_session(
If provided, the tag to create the session on.
snapshot_id : str, optional
If provided, the snapshot ID to create the session on.
as_of : datetime.datetime, optional
When combined with the branch argument, opens the session at the last
snapshot that is at or before this datetime.
Returns
-------
Expand All @@ -465,7 +469,7 @@ def readonly_session(
"""
return Session(
self._repository.readonly_session(
branch=branch, tag=tag, snapshot_id=snapshot_id
branch=branch, tag=tag, snapshot_id=snapshot_id, as_of=as_of
)
)

Expand Down
26 changes: 19 additions & 7 deletions icechunk-python/src/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ impl PyRepository {
let repo = Arc::clone(&self.0);
// This function calls block_on, so we need to allow other Python threads to make progress
py.allow_threads(move || {
let version = args_to_version_info(branch, tag, snapshot_id)?;
let version = args_to_version_info(branch, tag, snapshot_id, None)?;
let ancestry = pyo3_async_runtimes::tokio::get_runtime()
.block_on(async move { repo.ancestry_arc(&version).await })
.map_err(PyIcechunkStoreError::RepositoryError)?
Expand Down Expand Up @@ -701,8 +701,8 @@ impl PyRepository {
to_tag: Option<String>,
to_snapshot_id: Option<String>,
) -> PyResult<PyDiff> {
let from = args_to_version_info(from_branch, from_tag, from_snapshot_id)?;
let to = args_to_version_info(to_branch, to_tag, to_snapshot_id)?;
let from = args_to_version_info(from_branch, from_tag, from_snapshot_id, None)?;
let to = args_to_version_info(to_branch, to_tag, to_snapshot_id, None)?;

// This function calls block_on, so we need to allow other Python threads to make progress
py.allow_threads(move || {
Expand All @@ -717,17 +717,18 @@ impl PyRepository {
})
}

#[pyo3(signature = (*, branch = None, tag = None, snapshot_id = None))]
#[pyo3(signature = (*, branch = None, tag = None, snapshot_id = None, as_of = None))]
pub fn readonly_session(
&self,
py: Python<'_>,
branch: Option<String>,
tag: Option<String>,
snapshot_id: Option<String>,
as_of: Option<DateTime<Utc>>,
) -> PyResult<PySession> {
// This function calls block_on, so we need to allow other Python threads to make progress
py.allow_threads(move || {
let version = args_to_version_info(branch, tag, snapshot_id)?;
let version = args_to_version_info(branch, tag, snapshot_id, as_of)?;
let session =
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
Expand Down Expand Up @@ -841,6 +842,7 @@ fn args_to_version_info(
branch: Option<String>,
tag: Option<String>,
snapshot: Option<String>,
as_of: Option<DateTime<Utc>>,
) -> PyResult<VersionInfo> {
let n = [&branch, &tag, &snapshot].iter().filter(|r| !r.is_none()).count();
if n > 1 {
Expand All @@ -849,8 +851,18 @@ fn args_to_version_info(
));
}

if let Some(branch_name) = branch {
Ok(VersionInfo::BranchTipRef(branch_name))
if as_of.is_some() && branch.is_none() {
return Err(PyValueError::new_err(
"as_of argument must be provided together with a branch name",
));
}

if let Some(branch) = branch {
if let Some(at) = as_of {
Ok(VersionInfo::AsOf { branch, at })
} else {
Ok(VersionInfo::BranchTipRef(branch))
}
} else if let Some(tag_name) = tag {
Ok(VersionInfo::TagRef(tag_name))
} else if let Some(snapshot_id) = snapshot {
Expand Down
35 changes: 35 additions & 0 deletions icechunk-python/tests/test_timetravel.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,38 @@ async def test_tag_delete() -> None:

with pytest.raises(ValueError):
repo.create_tag("tag", snap)


async def test_session_with_as_of() -> None:
    """Check that ``readonly_session("main", as_of=...)`` time-travels on a branch.

    The test commits an empty root group followed by five commits that each
    add one child group, recording a timestamp after every commit. Opening
    "main" as of the i-th recorded timestamp must then expose exactly the
    first ``i`` children.
    """
    repo = ic.Repository.create(
        storage=ic.in_memory_storage(),
    )

    def commit_time(snapshot_id):
        # Timestamp of the first entry yielded by the ancestry of the commit
        # (presumably the commit's own snapshot -- confirm against ancestry()).
        return next(repo.ancestry(snapshot_id=snapshot_id)).written_at

    # First commit: an empty root group.
    writer = repo.writable_session("main")
    zarr.group(store=writer.store, overwrite=True)
    commit_times = [commit_time(writer.commit("root"))]

    # Five more commits, each adding a single child group.
    for idx in range(5):
        writer = repo.writable_session("main")
        root = zarr.open_group(store=writer.store)
        root.create_group(f"child {idx}")
        commit_times.append(commit_time(writer.commit(f"child {idx}")))

    history = list(repo.ancestry(branch="main"))
    assert len(history) == 7  # initial + root + 5 children

    # Opening at the newest recorded timestamp must succeed.
    reader = repo.readonly_session("main", as_of=commit_times[-1])
    zarr.open_group(store=reader.store, mode="r")

    # At the i-th timestamp only children 0..i-1 have been committed.
    for count, moment in enumerate(commit_times):
        reader = repo.readonly_session("main", as_of=moment)
        root = zarr.open_group(store=reader.store, mode="r")
        expected_children = {f"child {j}" for j in range(count)}
        actual_children = {member[0] for member in root.members()}
        assert expected_children == actual_children
35 changes: 28 additions & 7 deletions icechunk/src/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::{
sync::Arc,
};

use async_recursion::async_recursion;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use err_into::ErrorInto as _;
use futures::{
stream::{FuturesOrdered, FuturesUnordered},
Expand Down Expand Up @@ -37,15 +39,13 @@ use crate::{
Storage, StorageError,
};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum VersionInfo {
#[serde(rename = "snapshot_id")]
SnapshotId(SnapshotId),
#[serde(rename = "tag")]
TagRef(String),
#[serde(rename = "branch")]
BranchTipRef(String),
AsOf { branch: String, at: DateTime<Utc> },
}

#[derive(Debug, Error)]
Expand All @@ -60,6 +60,8 @@ pub enum RepositoryErrorKind {

#[error("snapshot not found: `{id}`")]
SnapshotNotFound { id: SnapshotId },
#[error("branch {branch} does not have a snapshots before or at {at}")]
InvalidAsOfSpec { branch: String, at: DateTime<Utc> },
#[error("invalid snapshot id: `{0}`")]
InvalidSnapshotId(String),
#[error("tag error: `{0}`")]
Expand Down Expand Up @@ -404,11 +406,12 @@ impl Repository {
}

/// Returns the sequence of parents of the snapshot pointed to by the given version
#[async_recursion(?Send)]
#[instrument(skip(self))]
pub async fn ancestry(
&self,
pub async fn ancestry<'a>(
&'a self,
version: &VersionInfo,
) -> RepositoryResult<impl Stream<Item = RepositoryResult<SnapshotInfo>> + '_> {
) -> RepositoryResult<impl Stream<Item = RepositoryResult<SnapshotInfo>> + 'a> {
let snapshot_id = self.resolve_version(version).await?;
self.snapshot_ancestry(&snapshot_id).await
}
Expand Down Expand Up @@ -572,6 +575,24 @@ impl Repository {
.await?;
Ok(ref_data.snapshot)
}
VersionInfo::AsOf { branch, at } => {
let tip = VersionInfo::BranchTipRef(branch.clone());
let snap = self
.ancestry(&tip)
.await?
.try_skip_while(|parent| ready(Ok(&parent.flushed_at > at)))
.take(1)
.try_collect::<Vec<_>>()
.await?;
match snap.into_iter().next() {
Some(snap) => Ok(snap.id),
None => Err(RepositoryErrorKind::InvalidAsOfSpec {
branch: branch.clone(),
at: *at,
}
.into()),
}
}
}
}

Expand Down

0 comments on commit a0be3c8

Please sign in to comment.