Skip to content

Commit

Permalink
fixed and working bands validation
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Apr 25, 2024
1 parent a58635b commit b00fe02
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
without leading or trailing non-alphanumeric characters.
- Add [`examples/item_eo_and_raster_bands.json`](examples/item_eo_and_raster_bands.json) demonstrating the original
use case represented by the previous [`examples/item_eo_bands.json`](examples/item_eo_bands.json) contents.
- Add a `description` field for `mlm:input` and `mlm:output` definitions.

### Changed
- Adjust `scikit-learn` and `Hugging Face` framework names to match the format employed by the official documentation.
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,10 @@ set to `true`, there would be no `accelerator` to contain against. To avoid conf

| Field Name | Type | Description |
|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. |
| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. |
| bands | \[string] | **REQUIRED** The names of the raster bands used to train or fine-tune the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. |
| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. |
| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. |
| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. |
| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. |
| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. |
Expand Down Expand Up @@ -400,6 +401,7 @@ the following formats are recommended as alternative scripts and function refere
| name | string | **REQUIRED** Name of the output variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"CLASSIFICATION"`) can be used instead. |
| tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the output can be used for. This can be a subset of `mlm:tasks` defined under the Item `properties` as applicable. |
| result | [Result Structure Object](#result-structure-object) | **REQUIRED** The structure that describes the resulting output arrays/tensors from one model head. |
| description | string | Additional details about the output such as describing its purpose or expected result that cannot be represented by other properties. |
| classification:classes | \[[Class Object](#class-object)] | A list of class objects adhering to the [Classification Extension](https://github.com/stac-extensions/classification). |
| post_processing_function | [Processing Expression](#processing-expression) \| null | Custom postprocessing function where normalization and rescaling, and any other significant operations takes place. |

Expand Down
8 changes: 8 additions & 0 deletions examples/collection.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@
"href": "item_eo_bands.json",
"rel": "item"
},
{
"href": "item_eo_and_raster_bands.json",
"rel": "item"
},
{
"href": "item_eo_bands_summarized.json",
"rel": "item"
},
{
"href": "item_raster_bands.json",
"rel": "item"
Expand Down
24 changes: 22 additions & 2 deletions examples/item_multi_io.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json"
],
"type": "Feature",
"id": "resnet-18_sentinel-2_all_moco_classification",
"id": "model-multi-input",
"collection": "ml-model-examples",
"geometry": {
"type": "Polygon",
Expand Down Expand Up @@ -43,7 +43,7 @@
58.21798141355221
],
"properties": {
"description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO",
"description": "Generic model that employs multiple input sources with different combination of bands.",
"datetime": null,
"start_datetime": "1900-01-01T00:00:00Z",
"end_datetime": "9999-12-31T23:59:59Z",
Expand Down Expand Up @@ -114,6 +114,26 @@
],
"data_type": "uint16"
}
},
{
"name": "DEM",
"description": "Digital elevation model. Comes from another source than the Sentinel bands. Therefore, no 'bands' associated to it.",
"bands": [],
"input": {
"shape": [
-1,
1,
64,
64
],
"dim_order": [
"batch",
"ndvi",
"height",
"width"
],
"data_type": "float32"
}
}
],
"mlm:output": [
Expand Down
29 changes: 19 additions & 10 deletions json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@
"input": {
"$ref": "#/$defs/InputStructure"
},
"description": {
"type": "string",
"minLength": 1
},
"norm_by_channel": {
"type": "boolean"
},
Expand Down Expand Up @@ -467,6 +471,10 @@
"result": {
"$ref": "#/$defs/ResultStructure"
},
"description": {
"type": "string",
"minLength": 1
},
"classification:classes": {
"$ref": "#/$defs/ClassificationClasses"
},
Expand Down Expand Up @@ -668,10 +676,11 @@
"AnyBandsRef": {
"$comment": "This definition ensures that, if at least 1 named MLM input 'bands' is provided, at least 1 of the supported references from EO, Raster or STAC Core 1.1 are provided as well. Otherwise, 'bands' must be explicitly empty.",
"if": {
"$comment": "This is the JSON-object 'properties' definition.",
"type": "object",
"properties": {
"$comment": "This is the STAC-Item 'properties' field.",
"properties": {
"type": "object",
"required": [
"mlm:input"
],
Expand All @@ -688,12 +697,7 @@
"properties": {
"bands": {
"type": "array",
"$comment": "This 'minItems' is the purpose of this whole 'if/then' block.",
"minItems": 1,
"items": {
"type": "string",
"minLength": 1
}
"minItems": 1
}
}
}
Expand All @@ -703,7 +707,7 @@
}
},
"then": {
"$comment": "Need at least one 'bands', but multiple is allowed.",
"$comment": "Need at least one 'bands' definition, but multiple are allowed.",
"anyOf": [
{
"$comment": "Bands described by raster extension.",
Expand Down Expand Up @@ -740,7 +744,7 @@
"$ref": "#/$defs/stac_extensions_eo"
},
{
"$comment": "EO extension expects at 'eo:bands' in (at least) 1 asset, and possibly in Item properties. Items are for summarizing. Since MLM also uses it by 'name' reference, allow any combination, and let 'eo' validate remaining combinations.",
"$comment": "EO extension expects at 'eo:bands' in (at least) 1 asset, and possibly in Item properties. Items are for summarizing. Since MLM also uses bands by 'name' reference, allow any combination, and let 'eo' validate remaining combinations.",
"anyOf": [
{
"$comment": "This is the JSON-object 'properties' definition.",
Expand Down Expand Up @@ -781,7 +785,9 @@
"properties": {
"$comment": "This is the STAC-Item 'properties' field.",
"properties": {
"required": ["bands"],
"required": [
"bands"
],
"$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.",
"properties": {
"$comment": "https://github.com/radiantearth/stac-spec/blob/bands/item-spec/common-metadata.md#bands",
Expand Down Expand Up @@ -817,6 +823,9 @@
"mlm:input": {
"type": "array",
"items": {
"required": [
"bands"
],
"$comment": "This is the 'Model Input Object' properties.",
"properties": {
"bands": {
Expand Down
2 changes: 1 addition & 1 deletion stac_model/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Literal, TypeAlias, Union
from typing import Any, Dict, List, Literal, TypeAlias, TypedDict, Union

from pydantic import BaseModel, ConfigDict, model_serializer

Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import json
import os
from typing import TYPE_CHECKING, Any, Dict, cast
from typing import TYPE_CHECKING, Any, Dict, List, cast

import pystac
import pytest
Expand All @@ -17,6 +18,16 @@
JSON_SCHEMA_DIR = os.path.abspath(os.path.join(TEST_DIR, "../json-schema"))


def get_all_stac_item_examples() -> List[str]:
all_json = glob.glob("**/*.json", root_dir=EXAMPLES_DIR, recursive=True)
all_geojson = glob.glob("**/*.geojson", root_dir=EXAMPLES_DIR, recursive=True)
all_stac_items = [
path for path in all_json + all_geojson
if os.path.splitext(os.path.basename(path))[0] not in ["collection", "catalog"]
]
return all_stac_items


@pytest.fixture(scope="session")
def mlm_schema() -> JSON:
with open(os.path.join(JSON_SCHEMA_DIR, "schema.json")) as schema_file:
Expand Down
29 changes: 20 additions & 9 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from typing import Any, Dict, cast
import os
from typing import Any, Dict, List, cast

import pystac
import pytest
Expand All @@ -8,17 +9,12 @@
from stac_model.base import JSON
from stac_model.schema import SCHEMA_URI

from conftest import get_all_stac_item_examples


@pytest.mark.parametrize(
"mlm_example", # value passed to 'mlm_example' fixture
[
"item_basic.json",
"item_raster_bands.json",
"item_eo_bands.json",
"item_eo_bands_summarized.json",
"item_eo_and_raster_bands.json",
"item_multi_io.json",
],
get_all_stac_item_examples(),
indirect=True,
)
def test_mlm_schema(
Expand Down Expand Up @@ -82,3 +78,18 @@ def test_validate_model_against_schema(eurosat_resnet, mlm_validator):
mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict())
validated = pystac.validation.validate(mlm_item, validator=mlm_validator)
assert SCHEMA_URI in validated


@pytest.mark.parametrize(
"mlm_example",
["collection.json"],
indirect=True,
)
def test_collection_include_all_items(mlm_example: JSON):
"""
This is only for self-validation, to make sure all examples are contained in the example STAC collection.
"""
col_links: List[JSON] = mlm_example["links"]
col_items = {os.path.basename(link["href"]) for link in col_links if link["rel"] == "item"}
all_items = {os.path.basename(path) for path in get_all_stac_item_examples()}
assert all_items == col_items, "Missing STAC Item examples in the example STAC Collection links."

0 comments on commit b00fe02

Please sign in to comment.