From b00fe021e8ef3d8f9557d629c5accdfe0a362ba8 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Thu, 25 Apr 2024 12:08:30 -0400 Subject: [PATCH] fixed and working bands validation --- CHANGELOG.md | 1 + README.md | 4 +++- examples/collection.json | 8 ++++++++ examples/item_multi_io.json | 24 ++++++++++++++++++++++-- json-schema/schema.json | 29 +++++++++++++++++++---------- stac_model/base.py | 2 +- tests/conftest.py | 13 ++++++++++++- tests/test_schema.py | 29 ++++++++++++++++++++--------- 8 files changed, 86 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff5edca..715c205 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 without leading or trailing non-alphanumeric characters. - Add [`examples/item_eo_and_raster_bands.json`](examples/item_eo_and_raster_bands.json) demonstrating the original use case represented by the previous [`examples/item_eo_bands.json`](examples/item_eo_bands.json) contents. +- Add a `description` field for `mlm:input` and `mlm:output` definitions. ### Changed - Adjust `scikit-learn` and `Hugging Face` framework names to match the format employed by the official documentation. diff --git a/README.md b/README.md index 19821ef..cc0bcdf 100644 --- a/README.md +++ b/README.md @@ -209,9 +209,10 @@ set to `true`, there would be no `accelerator` to contain against. To avoid conf | Field Name | Type | Description | |-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. | +| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. | | bands | \[string] | **REQUIRED** The names of the raster bands used to train or fine-tune the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. | | input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. | +| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. | | norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. | | norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | | norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. | @@ -400,6 +401,7 @@ the following formats are recommended as alternative scripts and function refere | name | string | **REQUIRED** Name of the output variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"CLASSIFICATION"`) can be used instead. | | tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the output can be used for. This can be a subset of `mlm:tasks` defined under the Item `properties` as applicable. | | result | [Result Structure Object](#result-structure-object) | **REQUIRED** The structure that describes the resulting output arrays/tensors from one model head. | +| description | string | Additional details about the output such as describing its purpose or expected result that cannot be represented by other properties. | | classification:classes | \[[Class Object](#class-object)] | A list of class objects adhering to the [Classification Extension](https://github.com/stac-extensions/classification). | | post_processing_function | [Processing Expression](#processing-expression) \| null | Custom postprocessing function where normalization and rescaling, and any other significant operations takes place. | diff --git a/examples/collection.json b/examples/collection.json index 7a71b3f..46c78ff 100644 --- a/examples/collection.json +++ b/examples/collection.json @@ -56,6 +56,14 @@ "href": "item_eo_bands.json", "rel": "item" }, + { + "href": "item_eo_and_raster_bands.json", + "rel": "item" + }, + { + "href": "item_eo_bands_summarized.json", + "rel": "item" + }, { "href": "item_raster_bands.json", "rel": "item" diff --git a/examples/item_multi_io.json b/examples/item_multi_io.json index d2f275d..fb246cd 100644 --- a/examples/item_multi_io.json +++ b/examples/item_multi_io.json @@ -7,7 +7,7 @@ "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" ], "type": "Feature", - "id": "resnet-18_sentinel-2_all_moco_classification", + "id": "model-multi-input", "collection": "ml-model-examples", "geometry": { "type": "Polygon", @@ -43,7 +43,7 @@ 58.21798141355221 ], "properties": { - "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", + "description": "Generic model that employs multiple input sources with different combination of bands.", "datetime": null, "start_datetime": "1900-01-01T00:00:00Z", "end_datetime": "9999-12-31T23:59:59Z", @@ -114,6 +114,26 @@ ], "data_type": "uint16" } + }, + { + "name": "DEM", + "description": "Digital elevation model. Comes from another source than the Sentinel bands. Therefore, no 'bands' associated to it.", + "bands": [], + "input": { + "shape": [ + -1, + 1, + 64, + 64 + ], + "dim_order": [ + "batch", + "ndvi", + "height", + "width" + ], + "data_type": "float32" + } } ], "mlm:output": [ diff --git a/json-schema/schema.json b/json-schema/schema.json index 8f90bf0..58bd76c 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -425,6 +425,10 @@ "input": { "$ref": "#/$defs/InputStructure" }, + "description": { + "type": "string", + "minLength": 1 + }, "norm_by_channel": { "type": "boolean" }, @@ -467,6 +471,10 @@ "result": { "$ref": "#/$defs/ResultStructure" }, + "description": { + "type": "string", + "minLength": 1 + }, "classification:classes": { "$ref": "#/$defs/ClassificationClasses" }, @@ -668,10 +676,11 @@ "AnyBandsRef": { "$comment": "This definition ensures that, if at least 1 named MLM input 'bands' is provided, at least 1 of the supported references from EO, Raster or STAC Core 1.1 are provided as well. Otherwise, 'bands' must be explicitly empty.", "if": { - "$comment": "This is the JSON-object 'properties' definition.", + "type": "object", "properties": { "$comment": "This is the STAC-Item 'properties' field.", "properties": { + "type": "object", "required": [ "mlm:input" ], @@ -688,12 +697,7 @@ "properties": { "bands": { "type": "array", - "$comment": "This 'minItems' is the purpose of this whole 'if/then' block.", - "minItems": 1, - "items": { - "type": "string", - "minLength": 1 - } + "minItems": 1 } } } @@ -703,7 +707,7 @@ } }, "then": { - "$comment": "Need at least one 'bands', but multiple is allowed.", + "$comment": "Need at least one 'bands' definition, but multiple are allowed.", "anyOf": [ { "$comment": "Bands described by raster extension.", @@ -740,7 +744,7 @@ "$ref": "#/$defs/stac_extensions_eo" }, { - "$comment": "EO extension expects at 'eo:bands' in (at least) 1 asset, and possibly in Item properties. Items are for summarizing. Since MLM also uses it by 'name' reference, allow any combination, and let 'eo' validate remaining combinations.", + "$comment": "EO extension expects at 'eo:bands' in (at least) 1 asset, and possibly in Item properties. Items are for summarizing. Since MLM also uses bands by 'name' reference, allow any combination, and let 'eo' validate remaining combinations.", "anyOf": [ { "$comment": "This is the JSON-object 'properties' definition.", @@ -781,7 +785,9 @@ "properties": { "$comment": "This is the STAC-Item 'properties' field.", "properties": { - "required": ["bands"], + "required": [ + "bands" + ], "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", "properties": { "$comment": "https://github.com/radiantearth/stac-spec/blob/bands/item-spec/common-metadata.md#bands", @@ -817,6 +823,9 @@ "mlm:input": { "type": "array", "items": { + "required": [ + "bands" + ], "$comment": "This is the 'Model Input Object' properties.", "properties": { "bands": { diff --git a/stac_model/base.py b/stac_model/base.py index 4e8cc6b..96c1d0f 100644 --- a/stac_model/base.py +++ b/stac_model/base.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Literal, TypeAlias, Union +from typing import Any, Dict, List, Literal, TypeAlias, TypedDict, Union from pydantic import BaseModel, ConfigDict, model_serializer diff --git a/tests/conftest.py b/tests/conftest.py index 996b1a6..3951a31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ +import glob import json import os -from typing import TYPE_CHECKING, Any, Dict, cast +from typing import TYPE_CHECKING, Any, Dict, List, cast import pystac import pytest @@ -17,6 +18,16 @@ JSON_SCHEMA_DIR = os.path.abspath(os.path.join(TEST_DIR, "../json-schema")) +def get_all_stac_item_examples() -> List[str]: + all_json = glob.glob("**/*.json", root_dir=EXAMPLES_DIR, recursive=True) + all_geojson = glob.glob("**/*.geojson", root_dir=EXAMPLES_DIR, recursive=True) + all_stac_items = [ + path for path in all_json + all_geojson + if os.path.splitext(os.path.basename(path))[0] not in ["collection", "catalog"] + ] + return all_stac_items + + @pytest.fixture(scope="session") def mlm_schema() -> JSON: with open(os.path.join(JSON_SCHEMA_DIR, "schema.json")) as schema_file: diff --git a/tests/test_schema.py b/tests/test_schema.py index 3f5253d..5408378 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Dict, cast +import os +from typing import Any, Dict, List, cast import pystac import pytest @@ -8,17 +9,12 @@ from stac_model.base import JSON from stac_model.schema import SCHEMA_URI +from conftest import get_all_stac_item_examples + @pytest.mark.parametrize( "mlm_example", # value passed to 'mlm_example' fixture - [ - "item_basic.json", - "item_raster_bands.json", - "item_eo_bands.json", - "item_eo_bands_summarized.json", - "item_eo_and_raster_bands.json", - "item_multi_io.json", - ], + get_all_stac_item_examples(), indirect=True, ) def test_mlm_schema( @@ -82,3 +78,18 @@ def test_validate_model_against_schema(eurosat_resnet, mlm_validator): mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict()) validated = pystac.validation.validate(mlm_item, validator=mlm_validator) assert SCHEMA_URI in validated + + +@pytest.mark.parametrize( + "mlm_example", + ["collection.json"], + indirect=True, +) +def test_collection_include_all_items(mlm_example: JSON): + """ + This is only for self-validation, to make sure all examples are contained in the example STAC collection. + """ + col_links: List[JSON] = mlm_example["links"] + col_items = {os.path.basename(link["href"]) for link in col_links if link["rel"] == "item"} + all_items = {os.path.basename(path) for path in get_all_stac_item_examples()} + assert all_items == col_items, "Missing STAC Item examples in the example STAC Collection links."