diff --git a/icechunk/src/zarr.rs b/icechunk/src/zarr.rs index f9e871713..28b757446 100644 --- a/icechunk/src/zarr.rs +++ b/icechunk/src/zarr.rs @@ -12,7 +12,7 @@ use std::{ use bytes::Bytes; use futures::{Stream, StreamExt, TryStreamExt}; use itertools::Itertools; -use serde::{Deserialize, Serialize}; +use serde::{de, Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, TryFromInto}; use thiserror::Error; @@ -612,6 +612,7 @@ impl Key { #[derive(Debug, Serialize, Deserialize)] struct ArrayMetadata { zarr_format: u8, + #[serde(deserialize_with = "validate_array_node_type")] node_type: String, attributes: Option, #[serde(flatten)] @@ -720,10 +721,43 @@ impl From for ZarrArrayMetadataSerialzer { #[derive(Debug, Serialize, Deserialize)] struct GroupMetadata { zarr_format: u8, + #[serde(deserialize_with = "validate_group_node_type")] node_type: String, attributes: Option, } +fn validate_group_node_type<'de, D>(d: D) -> Result +where + D: de::Deserializer<'de>, +{ + let value = String::deserialize(d)?; + + if value != "group" { + return Err(de::Error::invalid_value( + de::Unexpected::Str(value.as_str()), + &"the word 'group'", + )); + } + + Ok(value) +} + +fn validate_array_node_type<'de, D>(d: D) -> Result +where + D: de::Deserializer<'de>, +{ + let value = String::deserialize(d)?; + + if value != "array" { + return Err(de::Error::invalid_value( + de::Unexpected::Str(value.as_str()), + &"the word 'array'", + )); + } + + Ok(value) +} + impl ArrayMetadata { fn new(attributes: Option, zarr_metadata: ZarrArrayMetadata) -> Self { Self { zarr_format: 3, node_type: "array".to_string(), attributes, zarr_metadata } @@ -959,6 +993,27 @@ mod tests { ); } + #[test] + fn test_metadata_deserialization() { + assert!(serde_json::from_str::( + r#"{"zarr_format":3, "node_type":"group"}"# + ) + .is_ok()); + assert!(serde_json::from_str::( + r#"{"zarr_format":3, "node_type":"array"}"# + ) + .is_err()); + + assert!(serde_json::from_str::( + r#"{"zarr_format":3,"node_type":"array","shape":[2,2,2],"data_type":"int32","chunk_grid":{"name":"regular","configuration":{"chunk_shape":[1,1,1]}},"chunk_key_encoding":{"name":"default","configuration":{"separator":"/"}},"fill_value":0,"codecs":[{"name":"mycodec","configuration":{"foo":42}}],"storage_transformers":[{"name":"mytransformer","configuration":{"bar":43}}],"dimension_names":["x","y","t"]}"# + ) + .is_ok()); + assert!(serde_json::from_str::( + r#"{"zarr_format":3,"node_type":"group","shape":[2,2,2],"data_type":"int32","chunk_grid":{"name":"regular","configuration":{"chunk_shape":[1,1,1]}},"chunk_key_encoding":{"name":"default","configuration":{"separator":"/"}},"fill_value":0,"codecs":[{"name":"mycodec","configuration":{"foo":42}}],"storage_transformers":[{"name":"mytransformer","configuration":{"bar":43}}],"dimension_names":["x","y","t"]}"# + ) + .is_err()); + } + #[tokio::test] async fn test_metadata_set_and_get() -> Result<(), Box> { let storage: Arc =