Skip to content

Commit

Permalink
fix requirement of mlm:model role, required only by at least one asset
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Apr 27, 2024
1 parent 6d9943b commit 28c9a81
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- n/a

### Fixed
- n/a
- Fix the validation strategy of the `mlm:model` role required by at least one Asset under a STAC Item.
Although the role requirement was validated, the definition did not allow for other Assets without it to exist.

## [v1.1.0](https://github.com/crim-ca/mlm-extension/tree/v1.1.0)

Expand Down
55 changes: 51 additions & 4 deletions json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"$comment": "This is the schema for STAC extension MLM in Items.",
"allOf": [
{
"$comment": "Schema to validate the MLM fields under Item properties or Assets properties.",
"type": "object",
"required": [
"type",
Expand Down Expand Up @@ -40,10 +41,6 @@
"allOf": [
{
"$ref": "#/$defs/fields"
},
{
"$comment": "At least one Asset must provide the model definition.",
"$ref": "#/$defs/AssetModelRole"
}
]
}
Expand All @@ -52,6 +49,10 @@
},
{
"$ref": "#/$defs/stac_extensions_mlm"
},
{
"$comment": "Schema to validate model role requirement.",
"$ref": "#/$defs/AssetModelRoleMinimumOneDefinition"
}
]
},
Expand Down Expand Up @@ -646,6 +647,52 @@
"DataType": {
"$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/data_type"
},
"AssetModelRoleMinimumOneDefinition": {
"$comment": "At least one Asset must provide the model definition indicated by the 'mlm:model' role.",
"required": [
"assets"
],
"anyOf": [
{
"properties": {
"assets": {
"additionalProperties": {
"properties": {
"roles": {
"type": "array",
"items": {
"const": "mlm:model"
},
"minItems": 1
}
}
}
}
}
},
{
"not": {
"properties": {
"assets": {
"additionalProperties": {
"properties": {
"roles": {
"type": "array",
"items": {
"type": "string",
"not": {
"const": "mlm:model"
}
}
}
}
}
}
}
}
}
]
},
"AssetModelRole": {
"required": ["roles"],
"properties": {
Expand Down
55 changes: 55 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,61 @@ def test_mlm_no_input_allowed_but_explicit_empty_array_required(
pystac.validation.validate(mlm_item, validator=mlm_validator)


@pytest.mark.parametrize(
"mlm_example",
["item_basic.json"],
indirect=True,
)
def test_mlm_other_non_mlm_assets_allowed(
mlm_validator: STACValidator,
mlm_example: Dict[str, JSON],
) -> None:
mlm_data = copy.deepcopy(mlm_example)
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand

mlm_data["assets"]["sample"] = {
"type": "image/jpeg",
"href": "https://example.com/sample/output.jpg",
"roles": ["preview"],
"title": "Model Output Predictions Sample",
}
mlm_data["assets"]["model-cart"] = {
"type": "text/markdown",
"href": "https://example.com/sample/model.md",
"roles": ["metadata"],
"title": "Model Cart",
}
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # still valid


@pytest.mark.parametrize(
"mlm_example",
["item_basic.json"],
indirect=True,
)
def test_mlm_at_least_one_asset_model(
mlm_validator: STACValidator,
mlm_example: Dict[str, JSON],
) -> None:
mlm_data = copy.deepcopy(mlm_example)
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand

mlm_data["assets"] = { # needs at least 1 asset with role 'mlm:model'
"model": {
"type": "application/octet-stream; application=pytorch",
"href": "https://example.com/sample/checkpoint.pt",
"roles": ["checkpoint"],
"title": "Model Weights Checkpoint",
}
}
with pytest.raises(pystac.errors.STACValidationError):
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator)


def test_model_metadata_to_dict(eurosat_resnet):
assert eurosat_resnet.item.to_dict()

Expand Down

0 comments on commit 28c9a81

Please sign in to comment.