diff --git a/CHANGELOG.md b/CHANGELOG.md
index b5dfcec..43a816e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,7 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   otherwise (i.e.: `collections/{collectionID}/items/{itemID}`).
 
 ### Fixed
-- n/a
+- Fix the validation strategy of the `mlm:model` role required by at least one Asset under a STAC Item.
+  Although the role requirement was validated, the definition did not allow other Assets without that role to coexist.
 
 ## [v1.1.0](https://github.com/crim-ca/mlm-extension/tree/v1.1.0)
 
diff --git a/json-schema/schema.json b/json-schema/schema.json
index 2a60b7a..21cd5f2 100644
--- a/json-schema/schema.json
+++ b/json-schema/schema.json
@@ -8,6 +8,7 @@
       "$comment": "This is the schema for STAC extension MLM in Items.",
       "allOf": [
         {
+          "$comment": "Schema to validate the MLM fields under Item properties or Assets properties.",
           "type": "object",
           "required": [
             "type",
@@ -40,10 +41,6 @@
                 "allOf": [
                   {
                     "$ref": "#/$defs/fields"
-                  },
-                  {
-                    "$comment": "At least one Asset must provide the model definition.",
-                    "$ref": "#/$defs/AssetModelRole"
                   }
                 ]
               }
@@ -52,6 +49,10 @@
         },
         {
           "$ref": "#/$defs/stac_extensions_mlm"
+        },
+        {
+          "$comment": "Schema to validate model role requirement.",
+          "$ref": "#/$defs/AssetModelRoleMinimumOneDefinition"
         }
       ]
     },
@@ -652,6 +653,33 @@
     "DataType": {
       "$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/data_type"
     },
+    "AssetModelRoleMinimumOneDefinition": {
+      "$comment": "At least one Asset must provide the model definition indicated by the 'mlm:model' role. Written as a double negation ('not every Asset lacks the role') because 'contains' only applies to arrays, not to object values. Assets without the role, or without any roles, are otherwise left unconstrained.",
+      "required": [
+        "assets"
+      ],
+      "not": {
+        "properties": {
+          "assets": {
+            "additionalProperties": {
+              "not": {
+                "required": [
+                  "roles"
+                ],
+                "properties": {
+                  "roles": {
+                    "type": "array",
+                    "contains": {
+                      "const": "mlm:model"
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "AssetModelRole": {
       "required": ["roles"],
       "properties": {
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 3b26a3f..5cf04d5 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -66,6 +66,87 @@ def test_mlm_no_input_allowed_but_explicit_empty_array_required(
     pystac.validation.validate(mlm_item, validator=mlm_validator)
 
 
+@pytest.mark.parametrize(
+    "mlm_example",
+    ["item_basic.json"],
+    indirect=True,
+)
+def test_mlm_other_non_mlm_assets_allowed(
+    mlm_validator: STACValidator,
+    mlm_example: Dict[str, JSON],
+) -> None:
+    """Assets without the ``mlm:model`` role may coexist with the model Asset."""
+    mlm_data = copy.deepcopy(mlm_example)
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    pystac.validation.validate(mlm_item, validator=mlm_validator)  # self-check valid beforehand
+
+    mlm_data["assets"]["sample"] = {  # type: ignore
+        "type": "image/jpeg",
+        "href": "https://example.com/sample/output.jpg",
+        "roles": ["preview"],
+        "title": "Model Output Predictions Sample",
+    }
+    mlm_data["assets"]["model-card"] = {  # type: ignore
+        "type": "text/markdown",
+        "href": "https://example.com/sample/model.md",
+        "roles": ["metadata"],
+        "title": "Model Card",
+    }
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    pystac.validation.validate(mlm_item, validator=mlm_validator)  # still valid
+
+
+@pytest.mark.parametrize(
+    "mlm_example",
+    ["item_basic.json"],
+    indirect=True,
+)
+def test_mlm_at_least_one_asset_model(
+    mlm_validator: STACValidator,
+    mlm_example: Dict[str, JSON],
+) -> None:
+    """An Item whose Assets all lack the ``mlm:model`` role must be rejected."""
+    mlm_data = copy.deepcopy(mlm_example)
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    pystac.validation.validate(mlm_item, validator=mlm_validator)  # self-check valid beforehand
+
+    mlm_data["assets"] = {  # needs at least 1 asset with role 'mlm:model'
+        "model": {
+            "type": "application/octet-stream; application=pytorch",
+            "href": "https://example.com/sample/checkpoint.pt",
+            "roles": ["checkpoint"],
+            "title": "Model Weights Checkpoint",
+        }
+    }
+    # build the Item outside the context manager so only schema validation can satisfy 'raises'
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    with pytest.raises(pystac.errors.STACValidationError):
+        pystac.validation.validate(mlm_item, validator=mlm_validator)
+
+
+@pytest.mark.parametrize(
+    "mlm_example",
+    ["item_basic.json"],
+    indirect=True,
+)
+def test_mlm_asset_without_roles_is_not_model(
+    mlm_validator: STACValidator,
+    mlm_example: Dict[str, JSON],
+) -> None:
+    """Assets that omit ``roles`` entirely cannot satisfy the ``mlm:model`` requirement."""
+    mlm_data = copy.deepcopy(mlm_example)
+    mlm_data["assets"] = {  # no 'roles' at all: must not pass the requirement vacuously
+        "model": {
+            "type": "application/octet-stream; application=pytorch",
+            "href": "https://example.com/sample/checkpoint.pt",
+            "title": "Model Weights Checkpoint",
+        }
+    }
+    mlm_item = pystac.Item.from_dict(mlm_data)
+    with pytest.raises(pystac.errors.STACValidationError):
+        pystac.validation.validate(mlm_item, validator=mlm_validator)
+
+
 def test_model_metadata_to_dict(eurosat_resnet):
     assert eurosat_resnet.item.to_dict()
 