From a58635b7022416fdd5db551ed1ee05df2223ca2c Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Wed, 24 Apr 2024 19:44:39 -0400 Subject: [PATCH] [wip] add more tests to ensure behavior --- examples/item_eo_bands.json | 172 ++++------ examples/item_eo_bands_summarized.json | 442 +++++++++++++++++++++++++ json-schema/schema.json | 134 ++++++-- tests/test_schema.py | 41 +++ 4 files changed, 650 insertions(+), 139 deletions(-) create mode 100644 examples/item_eo_bands_summarized.json diff --git a/examples/item_eo_bands.json b/examples/item_eo_bands.json index 800ecde..26c1215 100644 --- a/examples/item_eo_bands.json +++ b/examples/item_eo_bands.json @@ -1,4 +1,5 @@ { + "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands directly in the Model Asset.", "stac_version": "1.0.0", "stac_extensions": [ "https://crim-ca.github.io/mlm-extension/v1.0.0/schema.json", @@ -260,99 +261,6 @@ ], "post_processing_function": null } - ], - "eo:bands": [ - { - "name": "B01", - "common_name": "coastal", - "description": "Coastal aerosol (band 1)", - "center_wavelength": 0.443, - "full_width_half_max": 0.027 - }, - { - "name": "B02", - "common_name": "blue", - "description": "Blue (band 2)", - "center_wavelength": 0.49, - "full_width_half_max": 0.098 - }, - { - "name": "B03", - "common_name": "green", - "description": "Green (band 3)", - "center_wavelength": 0.56, - "full_width_half_max": 0.045 - }, - { - "name": "B04", - "common_name": "red", - "description": "Red (band 4)", - "center_wavelength": 0.665, - "full_width_half_max": 0.038 - }, - { - "name": "B05", - "common_name": "rededge", - "description": "Red edge 1 (band 5)", - "center_wavelength": 0.704, - "full_width_half_max": 0.019 - }, - { - "name": "B06", - "common_name": "rededge", - "description": "Red edge 2 (band 6)", - "center_wavelength": 0.74, - "full_width_half_max": 0.018 - }, - { - "name": "B07", - "common_name": "rededge", - "description": "Red edge 3 (band 7)", - "center_wavelength": 0.783, - "full_width_half_max": 0.028 - }, - { - "name": "B08", - "common_name": "nir", - "description": "NIR 1 (band 8)", - "center_wavelength": 0.842, - "full_width_half_max": 0.145 - }, - { - "name": "B8A", - "common_name": "nir08", - "description": "NIR 2 (band 8A)", - "center_wavelength": 0.865, - "full_width_half_max": 0.033 - }, - { - "name": "B09", - "common_name": "nir09", - "description": "NIR 3 (band 9)", - "center_wavelength": 0.945, - "full_width_half_max": 0.026 - }, - { - "name": "B10", - "common_name": "cirrus", - "description": "SWIR - Cirrus (band 10)", - "center_wavelength": 1.375, - "full_width_half_max": 0.026 - }, - { - "name": "B11", - "common_name": "swir16", - "description": "SWIR 1 (band 11)", - "center_wavelength": 1.61, - "full_width_half_max": 0.143 - }, - { - "name": "B12", - "common_name": "swir22", - "description": "SWIR 2 (band 12)", - "center_wavelength": 2.19, - "full_width_half_max": 0.242 - } ] }, "assets": { @@ -368,43 +276,95 @@ "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", "eo:bands": [ { - "name": "coastal" + "name": "B01", + "common_name": "coastal", + "description": "Coastal aerosol (band 1)", + "center_wavelength": 0.443, + "full_width_half_max": 0.027 }, { - "name": "blue" + "name": "B02", + "common_name": "blue", + "description": "Blue (band 2)", + "center_wavelength": 0.49, + "full_width_half_max": 0.098 }, { - "name": "green" + "name": "B03", + "common_name": "green", + "description": "Green (band 3)", + "center_wavelength": 0.56, + "full_width_half_max": 0.045 }, { - "name": "red" + "name": "B04", + "common_name": "red", + "description": "Red (band 4)", + "center_wavelength": 0.665, + "full_width_half_max": 0.038 }, { - "name": "rededge1" + "name": "B05", + "common_name": "rededge", + "description": "Red edge 1 (band 5)", + "center_wavelength": 0.704, + "full_width_half_max": 0.019 }, { - "name": "rededge2" + "name": "B06", + "common_name": "rededge", + "description": "Red edge 2 (band 6)", + "center_wavelength": 0.74, + "full_width_half_max": 0.018 }, { - "name": "rededge3" + "name": "B07", + "common_name": "rededge", + "description": "Red edge 3 (band 7)", + "center_wavelength": 0.783, + "full_width_half_max": 0.028 }, { - "name": "nir" + "name": "B08", + "common_name": "nir", + "description": "NIR 1 (band 8)", + "center_wavelength": 0.842, + "full_width_half_max": 0.145 }, { - "name": "nir08" + "name": "B8A", + "common_name": "nir08", + "description": "NIR 2 (band 8A)", + "center_wavelength": 0.865, + "full_width_half_max": 0.033 }, { - "name": "nir09" + "name": "B09", + "common_name": "nir09", + "description": "NIR 3 (band 9)", + "center_wavelength": 0.945, + "full_width_half_max": 0.026 }, { - "name": "cirrus" + "name": "B10", + "common_name": "cirrus", + "description": "SWIR - Cirrus (band 10)", + "center_wavelength": 1.375, + "full_width_half_max": 0.026 }, { - "name": "swir16" + "name": "B11", + "common_name": "swir16", + "description": "SWIR 1 (band 11)", + "center_wavelength": 1.61, + "full_width_half_max": 0.143 }, { - "name": "swir22" + "name": "B12", + "common_name": "swir22", + "description": "SWIR 2 (band 12)", + "center_wavelength": 2.19, + "full_width_half_max": 0.242 } ] }, diff --git a/examples/item_eo_bands_summarized.json b/examples/item_eo_bands_summarized.json new file mode 100644 index 0000000..aa9dcd4 --- /dev/null +++ b/examples/item_eo_bands_summarized.json @@ -0,0 +1,442 @@ +{ + "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands summarized in the Item properties and referenced by name in the Model Asset.", + "stac_version": "1.0.0", + "stac_extensions": [ + "https://crim-ca.github.io/mlm-extension/v1.0.0/schema.json", + "https://stac-extensions.github.io/eo/v1.1.0/schema.json", + "https://stac-extensions.github.io/raster/v1.1.0/schema.json", + "https://stac-extensions.github.io/file/v1.0.0/schema.json", + "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" + ], + "type": "Feature", + "id": "resnet-18_sentinel-2_all_moco_classification", + "collection": "ml-model-examples", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -7.882190080512502, + 37.13739173208318 + ], + [ + -7.882190080512502, + 58.21798141355221 + ], + [ + 27.911651652899923, + 58.21798141355221 + ], + [ + 27.911651652899923, + 37.13739173208318 + ], + [ + -7.882190080512502, + 37.13739173208318 + ] + ] + ] + }, + "bbox": [ + -7.882190080512502, + 37.13739173208318, + 27.911651652899923, + 58.21798141355221 + ], + "properties": { + "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", + "datetime": null, + "start_datetime": "1900-01-01T00:00:00Z", + "end_datetime": "9999-12-31T23:59:59Z", + "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", + "mlm:tasks": [ + "classification" + ], + "mlm:architecture": "ResNet", + "mlm:framework": "pytorch", + "mlm:framework_version": "2.1.2+cu121", + "file:size": 43000000, + "mlm:memory_size": 1, + "mlm:total_parameters": 11700000, + "mlm:pretrained_source": "EuroSat Sentinel-2", + "mlm:accelerator": "cuda", + "mlm:accelerator_constrained": false, + "mlm:accelerator_summary": "Unknown", + "mlm:batch_size_suggestion": 256, + "mlm:input": [ + { + "name": "13 Band Sentinel-2 Batch", + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12" + ], + "input": { + "shape": [ + -1, + 13, + 64, + 64 + ], + "dim_order": [ + "batch", + "channel", + "height", + "width" + ], + "data_type": "float32" + }, + "norm_by_channel": true, + "norm_type": "z-score", + "resize_type": null, + "statistics": [ + { + "mean": 1354.40546513, + "stddev": 245.71762908 + }, + { + "mean": 1118.24399958, + "stddev": 333.00778264 + }, + { + "mean": 1042.92983953, + "stddev": 395.09249139 + }, + { + "mean": 947.62620298, + "stddev": 593.75055589 + }, + { + "mean": 1199.47283961, + "stddev": 566.4170017 + }, + { + "mean": 1999.79090914, + "stddev": 861.18399006 + }, + { + "mean": 2369.22292565, + "stddev": 1086.63139075 + }, + { + "mean": 2296.82608323, + "stddev": 1117.98170791 + }, + { + "mean": 732.08340178, + "stddev": 404.91978886 + }, + { + "mean": 12.11327804, + "stddev": 4.77584468 + }, + { + "mean": 1819.01027855, + "stddev": 1002.58768311 + }, + { + "mean": 1118.92391149, + "stddev": 761.30323499 + }, + { + "mean": 2594.14080798, + "stddev": 1231.58581042 + } + ], + "pre_processing_function": { + "format": "python", + "expression": "torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" + } + } + ], + "mlm:output": [ + { + "name": "classification", + "tasks": [ + "classification" + ], + "result": { + "shape": [ + -1, + 10 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "float32" + }, + "classification_classes": [ + { + "value": 0, + "name": "Annual Crop", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 1, + "name": "Forest", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 2, + "name": "Herbaceous Vegetation", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 3, + "name": "Highway", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 4, + "name": "Industrial Buildings", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 5, + "name": "Pasture", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 6, + "name": "Permanent Crop", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 7, + "name": "Residential Buildings", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 8, + "name": "River", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + }, + { + "value": 9, + "name": "SeaLake", + "description": null, + "title": null, + "color_hint": null, + "nodata": false + } + ], + "post_processing_function": null + } + ], + "eo:bands": [ + { + "name": "B01", + "common_name": "coastal", + "description": "Coastal aerosol (band 1)", + "center_wavelength": 0.443, + "full_width_half_max": 0.027 + }, + { + "name": "B02", + "common_name": "blue", + "description": "Blue (band 2)", + "center_wavelength": 0.49, + "full_width_half_max": 0.098 + }, + { + "name": "B03", + "common_name": "green", + "description": "Green (band 3)", + "center_wavelength": 0.56, + "full_width_half_max": 0.045 + }, + { + "name": "B04", + "common_name": "red", + "description": "Red (band 4)", + "center_wavelength": 0.665, + "full_width_half_max": 0.038 + }, + { + "name": "B05", + "common_name": "rededge", + "description": "Red edge 1 (band 5)", + "center_wavelength": 0.704, + "full_width_half_max": 0.019 + }, + { + "name": "B06", + "common_name": "rededge", + "description": "Red edge 2 (band 6)", + "center_wavelength": 0.74, + "full_width_half_max": 0.018 + }, + { + "name": "B07", + "common_name": "rededge", + "description": "Red edge 3 (band 7)", + "center_wavelength": 0.783, + "full_width_half_max": 0.028 + }, + { + "name": "B08", + "common_name": "nir", + "description": "NIR 1 (band 8)", + "center_wavelength": 0.842, + "full_width_half_max": 0.145 + }, + { + "name": "B8A", + "common_name": "nir08", + "description": "NIR 2 (band 8A)", + "center_wavelength": 0.865, + "full_width_half_max": 0.033 + }, + { + "name": "B09", + "common_name": "nir09", + "description": "NIR 3 (band 9)", + "center_wavelength": 0.945, + "full_width_half_max": 0.026 + }, + { + "name": "B10", + "common_name": "cirrus", + "description": "SWIR - Cirrus (band 10)", + "center_wavelength": 1.375, + "full_width_half_max": 0.026 + }, + { + "name": "B11", + "common_name": "swir16", + "description": "SWIR 1 (band 11)", + "center_wavelength": 1.61, + "full_width_half_max": 0.143 + }, + { + "name": "B12", + "common_name": "swir22", + "description": "SWIR 2 (band 12)", + "center_wavelength": 2.19, + "full_width_half_max": 0.242 + } + ] + }, + "assets": { + "weights": { + "href": "https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", + "title": "Pytorch weights checkpoint", + "description": "A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", + "type": "application/octet-stream; application=pytorch", + "roles": [ + "mlm:model", + "mlm:weights" + ], + "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", + "eo:bands": [ + { + "name": "coastal" + }, + { + "name": "blue" + }, + { + "name": "green" + }, + { + "name": "red" + }, + { + "name": "rededge1" + }, + { + "name": "rededge2" + }, + { + "name": "rededge3" + }, + { + "name": "nir" + }, + { + "name": "nir08" + }, + { + "name": "nir09" + }, + { + "name": "cirrus" + }, + { + "name": "swir16" + }, + { + "name": "swir22" + } + ] + }, + "source_code": { + "href": "https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207", + "title": "Model implementation.", + "description": "Source code to run the model.", + "type": "text/x-python", + "roles": [ + "mlm:model", + "code", + "metadata" + ] + } + }, + "links": [ + { + "rel": "collection", + "href": "./collection.json", + "type": "application/json" + }, + { + "rel": "self", + "href": "./item_eo_bands.json", + "type": "application/geo+json" + }, + { + "rel": "derived_from", + "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", + "type": "application/json", + "ml-aoi:split": "train" + } + ] +} diff --git a/json-schema/schema.json b/json-schema/schema.json index ed39c70..8f90bf0 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -117,6 +117,20 @@ } } }, + "stac_extensions_eo_bands": { + "required": ["eo:bands"], + "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "properties": { + "$comment": "https://github.com/stac-extensions/eo#item-properties-or-asset-fields", + "eo:bands": { + "type": "array", + "minItems": 1, + "items": { + "type": "object" + } + } + } + }, "stac_extensions_raster": { "type": "object", "required": [ @@ -132,6 +146,20 @@ } } }, + "stac_extensions_raster_bands": { + "required": ["raster:bands"], + "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "properties": { + "$comment": "https://github.com/stac-extensions/raster#item-asset-fields", + "raster:bands": { + "type": "array", + "minItems": 1, + "items": { + "type": "object" + } + } + } + }, "stac_version_1.1": { "$comment": "Requirement for STAC 1.1 or above.", "type": "object", @@ -624,7 +652,7 @@ "ModelBands": { "allOf": [ { - "$comment": "No 'minItems' here since to support model inputs not using any band (other data source).", + "$comment": "No 'minItems' here to support model inputs not using any band (other data source).", "type": "array", "items": { "type": "string", @@ -678,6 +706,7 @@ "$comment": "Need at least one 'bands', but multiple is allowed.", "anyOf": [ { + "$comment": "Bands described by raster extension.", "allOf": [ { "$ref": "#/$defs/stac_extensions_raster" @@ -705,33 +734,44 @@ ] }, { + "$comment": "Bands described by eo extension.", "allOf": [ { "$ref": "#/$defs/stac_extensions_eo" }, { - "$comment": "This is the JSON-object 'properties' definition.", - "properties": { - "$comment": "This is the STAC-Item 'properties' field.", - "properties": { - "required": ["eo:bands"], - "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "$comment": "EO extension expects at 'eo:bands' in (at least) 1 asset, and possibly in Item properties. Items are for summarizing. Since MLM also uses it by 'name' reference, allow any combination, and let 'eo' validate remaining combinations.", + "anyOf": [ + { + "$comment": "This is the JSON-object 'properties' definition.", "properties": { - "$comment": "https://github.com/stac-extensions/eo#item-properties-or-asset-fields", - "eo:bands": { - "type": "array", - "minItems": 1, - "items": { - "type": "object" + "$comment": "This is the STAC-Item 'properties' field.", + "properties": { + "$ref": "#/$defs/stac_extensions_eo_bands" + } + } + }, + { + "$comment": "For the case where 'eo:bands' is in the Asset of the model, it must also contain the 'mlm:model' role.", + "properties": { + "assets": { + "additionalProperties": { + "if": { + "$ref": "#/$defs/AssetModelRole" + }, + "then": { + "$ref": "#/$defs/stac_extensions_eo_bands" + } } } } } - } + ] } ] }, { + "$comment": "Bands described by STAC Core 1.1.", "allOf": [ { "$ref": "#/$defs/stac_version_1.1" @@ -761,32 +801,60 @@ ] }, "else": { - "$comment": "This is the JSON-object 'properties' definition.", - "properties": { - "$comment": "This is the STAC-Item 'properties' field.", - "properties": { - "required": [ - "mlm:input" - ], - "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "$comment": "Case where no 'bands' are referenced in the MLM input. Counter-validate there are no 'eo:bands' or 'raster:bands' in the Model Asset.", + "allOf": [ + { + "$comment": "This is the JSON-object 'properties' definition.", "properties": { - "$comment": "Required MLM bands listing referring to at least one band name.", - "mlm:input": { - "type": "array", - "items": { - "$comment": "This is the 'Model Input Object' properties.", - "properties": { - "bands": { - "$comment": "No bands reference provided, therefore none permitted in model inputs.", - "type": "array", - "maxItems": 0 + "$comment": "This is the STAC-Item 'properties' field.", + "properties": { + "required": [ + "mlm:input" + ], + "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "properties": { + "$comment": "Required MLM bands listing referring to at least one band name.", + "mlm:input": { + "type": "array", + "items": { + "$comment": "This is the 'Model Input Object' properties.", + "properties": { + "bands": { + "$comment": "No bands reference provided, therefore none permitted in model inputs.", + "type": "array", + "maxItems": 0 + } + } + } + } + } + } + } + }, + { + "properties": { + "assets": { + "additionalProperties": { + "if": { + "$ref": "#/$defs/AssetModelRole" + }, + "then": { + "not": { + "anyOf": [ + { + "$ref": "#/$defs/stac_extensions_eo_bands" + }, + { + "$ref": "#/$defs/stac_extensions_raster_bands" + } + ] } } } } } } - } + ] } } } diff --git a/tests/test_schema.py b/tests/test_schema.py index 3fb2b6b..3f5253d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Dict, cast import pystac @@ -14,6 +15,7 @@ "item_basic.json", "item_raster_bands.json", "item_eo_bands.json", + "item_eo_bands_summarized.json", "item_eo_and_raster_bands.json", "item_multi_io.json", ], @@ -29,6 +31,45 @@ def test_mlm_schema( assert SCHEMA_URI in validated +@pytest.mark.parametrize( + "mlm_example", + ["item_eo_bands_summarized.json"], + indirect=True, +) +def test_mlm_eo_bands_invalid_only_in_item_properties( + mlm_validator: STACValidator, + mlm_example: JSON, +) -> None: + mlm_item = pystac.Item.from_dict(mlm_example) + pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid + + mlm_eo_bands_bad_data: Dict[str, JSON] = copy.deepcopy(mlm_example) + mlm_eo_bands_bad_data["assets"]["weights"].pop("eo:bands") + with pytest.raises(pystac.errors.STACValidationError): + mlm_eo_bands_bad_item = pystac.Item.from_dict(mlm_eo_bands_bad_data) + pystac.validation.validate(mlm_eo_bands_bad_item, validator=mlm_validator) + + +@pytest.mark.parametrize( + "mlm_example", + ["item_basic.json"], + indirect=True, +) +def test_mlm_no_input_allowed_but_explicit_empty_array_required( + mlm_validator: STACValidator, + mlm_example: JSON, +) -> None: + mlm_data = copy.deepcopy(mlm_example) + mlm_data["properties"]["mlm:input"] = [] + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) + + with pytest.raises(pystac.errors.STACValidationError): + mlm_data["properties"].pop("mlm:input") + mlm_item = pystac.Item.from_dict(mlm_data) + pystac.validation.validate(mlm_item, validator=mlm_validator) + + def test_model_metadata_to_dict(eurosat_resnet): assert eurosat_resnet.item.to_dict()