Skip to content

Commit

Permalink
update pydantic models with new json-schema fields
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Apr 4, 2024
1 parent 2b87297 commit 269bd73
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 152 deletions.
6 changes: 5 additions & 1 deletion json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,11 @@
]
},
"NormalizeClip": {

"type": "array",
"minItems": 1,
"items": {
"type": "number"
}
},
"ResizeType": {
"oneOf": [
Expand Down
66 changes: 66 additions & 0 deletions stac_model/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from enum import Enum
from typing import Any, Literal, Union, TypeAlias

from pydantic import BaseModel


DataType: TypeAlias = Literal[
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
"cint16",
"cint32",
"cfloat32",
"cfloat64",
"other"
]


class TaskEnum(str, Enum):
    """Enumeration of the supported machine-learning task names.

    The class mixes in ``str``, so every member is itself a string and
    compares equal to its hyphenated value (for example
    ``TaskEnum.OBJECT_DETECTION == "object-detection"``). Members can be
    looked up from the raw string with ``TaskEnum("object-detection")``.
    """

    REGRESSION = "regression"
    CLASSIFICATION = "classification"
    SCENE_CLASSIFICATION = "scene-classification"
    DETECTION = "detection"
    OBJECT_DETECTION = "object-detection"
    SEGMENTATION = "segmentation"
    SEMANTIC_SEGMENTATION = "semantic-segmentation"
    INSTANCE_SEGMENTATION = "instance-segmentation"
    PANOPTIC_SEGMENTATION = "panoptic-segmentation"
    SIMILARITY_SEARCH = "similarity-search"
    GENERATIVE = "generative"
    IMAGE_CAPTIONING = "image-captioning"
    SUPER_RESOLUTION = "super-resolution"


ModelTaskNames: TypeAlias = Literal[
"regression",
"classification",
"scene-classification",
"detection",
"object-detection",
"segmentation",
"semantic-segmentation",
"instance-segmentation",
"panoptic-segmentation",
"similarity-search",
"generative",
"image-captioning",
"super-resolution"
]


# A task is accepted either as one of the literal task-name strings or as
# a `TaskEnum` member.
ModelTask: TypeAlias = Union[ModelTaskNames, TaskEnum]


class ProcessingExpression(BaseModel):
    """A processing step given as an expression in a named format.

    Attributes:
        format: Name of the format in which ``expression`` is written
            (the bundled example uses ``"python"``).
        expression: The expression itself. It is typed ``Any``, so no
            structure is enforced on it by this model.
    """

    # FIXME: should use 'pystac' reference, but 'processing' extension is not implemented yet!
    format: str
    expression: Any
110 changes: 61 additions & 49 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import pystac
import json
import shapely
from stac_model.base import ProcessingExpression
from stac_model.input import ModelInput
from stac_model.output import ModelOutput, ModelResult
from stac_model.schema import (
Asset,
ClassObject,
InputArray,
MLMClassification,
MLModelExtension,
MLModelProperties,
ModelInput,
ModelOutput,
ResultArray,
Runtime,
Statistics,
)


def eurosat_resnet() -> MLModelExtension[pystac.Item]:
input_array = InputArray(
shape=[-1, 13, 64, 64], dim_order="bchw", data_type="float32"
shape=[-1, 13, 64, 64],
dim_order=[
"batch",
"channel",
"height",
"width"
],
data_type="float32",
)
band_names = [
"B01",
Expand Down Expand Up @@ -69,29 +76,34 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
input = ModelInput(
name="13 Band Sentinel-2 Batch",
bands=band_names,
input_array=input_array,
input=input_array,
norm_by_channel=True,
norm_type="z_score",
resize_type="none",
norm_type="z-score",
resize_type=None,
statistics=stats,
pre_processing_function="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn", # noqa: E501
)
runtime = Runtime(
framework="torch",
version="2.1.2+cu121",
asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501
type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501
),
source_code=Asset(
href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501
),
accelerator="cuda",
accelerator_constrained=False,
hardware_summary="Unknown",
commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a",
pre_processing_function=ProcessingExpression(
format="python",
expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn"
), # noqa: E501
)
result_array = ResultArray(
shape=[-1, 10], dim_order=["batch", "class"], data_type="float32"
# runtime = Runtime(
# framework="torch",
# version="2.1.2+cu121",
# asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501
# type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501
# ),
# source_code=Asset(
# href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501
# ),
# accelerator="cuda",
# accelerator_constrained=False,
# hardware_summary="Unknown",
# commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a",
# )
result_array = ModelResult(
shape=[-1, 10],
dim_order=["batch", "class"],
data_type="float32"
)
class_map = {
"Annual Crop": 0,
Expand All @@ -106,30 +118,26 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
"SeaLake": 9,
}
class_objects = [
ClassObject(value=class_map[class_name], name=class_name)
for class_name in class_map
MLMClassification(value=class_value, name=class_name)
for class_name, class_value in class_map.items()
]
output = ModelOutput(
task="classification",
classification_classes=class_objects,
output_shape=[-1, 10],
result_array=[result_array],
name="classification",
tasks={"classification"},
classes=class_objects,
result=result_array,
post_processing_function=None,
)
ml_model_meta = MLModelProperties(
name="Resnet-18 Sentinel-2 ALL MOCO",
task="classification",
tasks={"classification"},
framework="pytorch",
framework_version="2.1.2+cu121",
file_size=43000000,
memory_size=1,
summary=(
"Sourced from torchgeo python library,"
"identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
),
pretrained_source="EuroSat Sentinel-2",
total_parameters=11_700_000,
input=[input],
runtime=[runtime],
output=[output],
)
# TODO, this can't be serialized but pystac.item calls for a datetime
Expand All @@ -138,26 +146,30 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
start_datetime = "1900-01-01"
end_datetime = None
bbox = [
-7.882190080512502,
37.13739173208318,
27.911651652899923,
58.21798141355221
]
geometry = json.dumps(shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__, indent=2)
name = (
"_".join(ml_model_meta.name.split(" ")).lower()
+ f"_{ml_model_meta.task}".lower()
)
-7.882190080512502,
37.13739173208318,
27.911651652899923,
58.21798141355221
]
geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__
name = "_".join(ml_model_meta.name.split(" ")).lower()
item = pystac.Item(
id=name,
geometry=geometry,
bbox=bbox,
datetime=None,
properties={"start_datetime": start_datetime, "end_datetime": end_datetime},
properties={
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": (
"Sourced from torchgeo python library,"
"identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
),
},
)
item.add_derived_from(
"https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a"
)
item_mlm = MLModelExtension.ext(item, add_if_missing=True)
item_mlm.apply(ml_model_meta.model_dump())
item_mlm.apply(ml_model_meta.model_dump(by_alias=True))
return item_mlm
65 changes: 40 additions & 25 deletions stac_model/input.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Dict, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Set, TypeAlias, Union

from pydantic import AnyUrl, BaseModel, Field
from pydantic import BaseModel, Field

from stac_model.base import DataType, ProcessingExpression


class InputArray(BaseModel):
    """Shape, dimension order and element type of a model input array.

    Attributes:
        shape: Size of each dimension; at least one entry is required.
        dim_order: Name of each dimension; at least one entry is required.
        data_type: Element type, restricted to the names in ``DataType``.
    """

    # `min_length` is the pydantic v2 spelling of the list-size constraint;
    # the older `min_items` keyword is deprecated there.
    shape: List[Union[int, float]] = Field(..., min_length=1)
    dim_order: List[str] = Field(..., min_length=1)
    data_type: DataType


class Statistics(BaseModel):
Expand All @@ -24,29 +23,45 @@ class Statistics(BaseModel):
class Band(BaseModel):
    """Description of a single band.

    Attributes:
        name: Band name.
        description: Optional free-text description.
        nodata: Required no-data marker; may be a float, an int or a string.
        data_type: Element type of the band, as a free-form string.
        unit: Optional unit label.
    """

    name: str
    description: Optional[str] = None
    # Declared once, using `Union` like the rest of this module.
    nodata: Union[float, int, str]
    data_type: str
    unit: Optional[str] = None


NormalizeType: TypeAlias = Optional[Literal[
"min-max",
"z-score",
"l1",
"l2",
"l2sqr",
"hamming",
"hamming2",
"type-mask",
"relative",
"inf"
]]

ResizeType: TypeAlias = Optional[Literal[
"crop",
"pad",
"interpolation-nearest",
"interpolation-linear",
"interpolation-cubic",
"interpolation-area",
"interpolation-lanczos4",
"interpolation-max",
"wrap-fill-outliers",
"wrap-inverse-map"
]]


class ModelInput(BaseModel):
    """Description of one model input and its optional pre-processing.

    Attributes:
        name: Name of the input.
        bands: Names of the bands that make up the input.
        input: Array description (shape, dimension order, data type).
        norm_by_channel: Optional flag; defaults to ``None`` when omitted.
        norm_type: Optional normalization name from ``NormalizeType``.
        norm_clip: Optional list of numeric clip values.
        resize_type: Optional resize name from ``ResizeType``.
        statistics: Optional statistics, either one object or a list.
        pre_processing_function: Optional pre-processing expression.
    """

    name: str
    bands: List[str]
    # Field name is part of the model's interface, so it stays `input`
    # even though it shadows the builtin inside the class body.
    input: InputArray
    # The default is None, so the annotation must admit None as well.
    norm_by_channel: Optional[bool] = None
    norm_type: NormalizeType = None
    norm_clip: Optional[List[Union[float, int]]] = None
    resize_type: ResizeType = None
    statistics: Optional[Union[Statistics, List[Statistics]]] = None
    pre_processing_function: Optional[ProcessingExpression] = None
Loading

0 comments on commit 269bd73

Please sign in to comment.