44 translate cropno crop inference udf to gfmap #48

Merged
merged 38 commits into from
Jun 12, 2024
Changes from 37 commits
Commits
de4540d
added mvp repo for minimal inference workflow
HansVRP May 6, 2024
c749633
minimal presto functionality
HansVRP May 6, 2024
1020ee4
hv remove pandas to xarray conversion
HansVRP May 7, 2024
bc9bd1a
Succesful run, todo is fix apply metadata for bands
HansVRP May 7, 2024
2ec4ebd
rework UDF and include presto UDF
HansVRP May 8, 2024
f17e0a8
fix: resolve presto specific UDF and include udf_long which does not …
HansVRP May 16, 2024
6ae5da2
fix: test remote inference
HansVRP May 21, 2024
4a3b74b
fix: dynamic size
HansVRP May 21, 2024
f3f4b15
Work in xarray as much as possible
kvantricht May 21, 2024
e0c1d05
Fix typing errors
kvantricht May 21, 2024
433f001
fix: inference
HansVRP May 22, 2024
af151f7
fix: udf_long
HansVRP May 22, 2024
44f9651
Updated UDF (still flips result though!)
kvantricht May 23, 2024
e0ca616
user order="F" for reshaping fixes the flipping issue
kvantricht May 23, 2024
7968ba0
Avoid use of rearrange. Bug remains.
kvantricht May 24, 2024
a579be7
Avoid the use of np.swapaxes
kvantricht May 24, 2024
42218f0
Add a comment for clarification
kvantricht May 24, 2024
919391c
Updated inference notebook
kvantricht May 24, 2024
9f105e6
Merge branch 'kvt_mvp_inferenceUDF' of https://github.com/WorldCereal…
GriffinBabe May 27, 2024
b74ecad
Updated inference notebook
kvantricht May 27, 2024
29b3034
Merge branch 'hv_mvp_inferenceUDF' of github.com:WorldCereal/worldcer…
kvantricht May 27, 2024
3e03ab4
Merge pull request #46 from WorldCereal/kvt_mvp_inferenceUDF
HansVRP May 27, 2024
7915b93
Updating preprocessing to match better kristof's results
GriffinBabe May 28, 2024
005841a
Added feature extractor with GFMAP compatibility
GriffinBabe May 28, 2024
f7d09b9
fix: clean-up + updated dependencies
HansVRP May 29, 2024
63722e5
Added presto feature computer using GFMAP
GriffinBabe May 31, 2024
14ff604
Merge branch 'hv_mvp_inferenceUDF' into 44-translate-cropno-crop-infe…
GriffinBabe May 31, 2024
5ed426b
UDFs are passing and reformatting for repository
GriffinBabe May 31, 2024
b443e8b
Cleaned up more by deleting a few duplicate codes
GriffinBabe May 31, 2024
2add215
Merge branch '44-translate-cropno-crop-inference-udf-to-gfmap' of htt…
GriffinBabe May 31, 2024
3251919
Fixed conflicts
GriffinBabe May 31, 2024
7b7ca4d
Implemented changed request by kristof
GriffinBabe Jun 3, 2024
3faef72
make use of external dependency through whl
kvantricht Jun 3, 2024
aa423c6
Merge branch 'main' into 44-translate-cropno-crop-inference-udf-to-gfmap
kvantricht Jun 7, 2024
8723fae
Changed to work with new openeo way of handling dependencies
GriffinBabe Jun 7, 2024
34e4621
Merge branch '44-translate-cropno-crop-inference-udf-to-gfmap' of htt…
GriffinBabe Jun 7, 2024
df87509
Now working with dependency as zip file and presto code packed as whe…
GriffinBabe Jun 11, 2024
20746f4
Changed dependencies .zip file
GriffinBabe Jun 12, 2024
15 changes: 15 additions & 0 deletions .gitignore
@@ -168,3 +168,18 @@ notebooks/S1A_IW_GRDH_1SDV_20191026T153410_20191026T153444_029631_035FDA_2640.SA
scripts/classification/tenpercent_sparse/.nfs00000000c35c9cfd00000035
download.zip
catboost_info/catboost_training.json

*.cbm
*.pt
*.onnx
*.nc
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

.notebook-tests/
93 changes: 70 additions & 23 deletions scripts/inference/cropland_mapping.py
@@ -1,21 +1,19 @@
"""Cropland mapping inference script, demonstrating the use of the GFMAP, Presto and WorldCereal classifiers in a first inference pipeline."""

import argparse
from pathlib import Path

import openeo
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import Backend, BackendContext, cdse_connection
from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap


class PrestoFeatureExtractor(PatchFeatureExtractor):
def __init__(self):
pass

def extract(self, image):
pass

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
GriffinBabe marked this conversation as resolved.
Contributor

Why is this still needed? onnxruntime is included in the worldcereal_deps.zip. And if it's still needed, shouldn't we also put it on s3?

Contributor Author

@GriffinBabe GriffinBabe Jun 11, 2024

Because it's much lighter than all the presto dependencies, so it's a waste of resources to install everything just for onnxruntime. It's also a requirement in GFMAP for the ModelInference subclasses, until the pip dependency feature is more stable.

Indeed, could you upload it to the S3 @HansVRP ?

Contributor

Does this mean that for each single UDF in the chain those dependencies get downloaded/unpacked again? That's also lots of overhead I guess?

Contributor Author

@GriffinBabe GriffinBabe Jun 12, 2024

In this case it's just onnxruntime that gets reinstalled (quite light). Compared to the full UDF, I don't think there is a large difference in execution time, and probably even less with large spatial extents.

Contributor

If cached, they should not be reloaded each time.
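
The caching idea from this thread can be sketched as follows. This is only an illustration under assumptions: `fetch_archive` and the `download_calls` counter are hypothetical names, not GFMAP or openEO API, and no real download happens here. Note that `functools.lru_cache` only helps when consecutive UDF calls run in the same Python process.

```python
import functools

# Counts how often the simulated download actually runs.
download_calls = {"count": 0}


@functools.lru_cache(maxsize=None)
def fetch_archive(url: str) -> str:
    """Hypothetical helper: pretend to download and unpack `url` once per process."""
    download_calls["count"] += 1
    # A real UDF would download and extract here; this sketch only derives a folder name.
    archive_name = url.rsplit("/", 1)[-1]
    return "deps_" + archive_name.rsplit(".", 1)[0]


# The second call with the same URL is served from the cache.
first = fetch_archive("https://example.com/onnx_dependencies.zip")
second = fetch_archive("https://example.com/onnx_dependencies.zip")
```

Whether the backend reuses a worker process between UDF invocations is a separate question that this sketch does not answer.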


if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -27,12 +25,20 @@ def extract(self, image):
parser.add_argument("miny", type=float, help="Minimum Y coordinate (south)")
parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)")
parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)")
parser.add_argument(
"--epsg",
type=int,
default=4326,
help="EPSG code for coordinate reference system.",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
parser.add_argument(
"output_folder", type=str, help="Path to folder where to save results."
"output_path",
type=Path,
help="Path to folder where to save the resulting NetCDF.",
)

args = parser.parse_args()
@@ -41,29 +47,70 @@ def extract(self, image):
miny = args.miny
maxx = args.maxx
maxy = args.maxy
epsg = args.epsg

start_date = args.start_date
end_date = args.end_date

spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy)
spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
temporal_extent = TemporalContext(start_date, end_date)

backend = BackendContext(Backend.CDSE)
backend_context = BackendContext(Backend.FED)

connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
input_cube = worldcereal_preprocessed_inputs_gfmap(
connection=cdse_connection(),
backend_context=backend,
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Start the job and download
job = input_cube.create_job(
title=f"Cropland inference BBOX: {minx} {miny} {maxx} {maxy}",
description="Cropland inference using WorldCereal, Presto and GFMAP classifiers",
out_format="NetCDF",
# Test feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
cube=inputs,
parameters=presto_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

catboost_parameters = {}

classes = apply_model_inference(
model_inference_class=CroplandClassifier,
cube=features,
parameters=catboost_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)
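
To make the `size` and `overlap` arguments above more concrete, here is a rough one-dimensional sketch of how a pixel axis could be split into 100 px chunks. `chunk_windows` is a hypothetical helper written for illustration; it assumes edge chunks are clipped to the axis, whereas an actual backend may pad them instead.

```python
def chunk_windows(length: int, size: int, overlap: int) -> list:
    """Return (start, stop) pixel windows along one axis.

    Each window has a core of `size` pixels, extended by `overlap` pixels on
    both sides and clipped to the axis bounds (an assumption of this sketch).
    """
    windows = []
    for core_start in range(0, length, size):
        start = max(0, core_start - overlap)
        stop = min(length, core_start + size + overlap)
        windows.append((start, stop))
    return windows


# With overlap 0, as in the script, the chunks tile the axis without sharing pixels.
no_overlap = chunk_windows(250, 100, 0)
# A non-zero overlap makes neighbouring chunks share pixels at their borders.
with_overlap = chunk_windows(250, 100, 16)
```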

job.start_and_wait()
job.get_results().download_files(args.output_folder)
classes.execute_batch(
outputfile=args.output_path,
out_format="NetCDF",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
Contributor

That's very high. Need to flag this for further profiling. But for now we can leave it like this.

Contributor

From what I have seen it is potentially overkill; a 200 km square ran on half the amount of memory for me. @GriffinBabe, do you have a job id for a job which needed more memory?

"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
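
The `udf-dependency-archives` entry above uses a `url#folder` form: as far as the convention goes, the part after `#` names the folder the archive is unpacked into on the backend. A small sketch of how such an entry splits; `split_archive_entry` is a hypothetical helper, not openEO API, and the backend's own parsing may differ.

```python
from urllib.parse import urldefrag

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


def split_archive_entry(entry: str) -> tuple:
    """Split a 'url#folder' entry into the archive URL and the target folder name."""
    url, folder = urldefrag(entry)
    return url, folder


archive_url, target_folder = split_archive_entry(f"{ONNX_DEPS_URL}#onnx_deps")
```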
111 changes: 111 additions & 0 deletions src/worldcereal/openeo/feature_extractor.py
@@ -0,0 +1,111 @@
"""Feature computer GFMAP compatible to compute Presto embeddings."""

import xarray as xr
from openeo.udf import XarrayDataCube
from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor


class PrestoFeatureExtractor(PatchFeatureExtractor):
"""Feature extractor to use Presto model to compute per-pixel embeddings.
This will generate a datacube with 128 bands, each band representing a
feature from the Presto model.

Interesting UDF parameters:
- presto_model_url: A public URL to the Presto model file. A default Presto
version is provided if the parameter is left undefined.
- rescale_s1: Is specifically disabled by default, as the presto
dependencies already take care of the backscatter decompression. If
specified, should be set as `False`.
"""

PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-temp-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
DEPENDENCY_NAME = "worldcereal_deps.zip"

GFMAP_BAND_MAPPING = {
Contributor

This is in fact a bit of overhead if, later on in presto-worldcereal, we remap again: https://github.com/WorldCereal/presto-worldcereal/blob/main/presto/inference.py#L41

So we should consider using the presto naming here directly and removing the remapping in presto-worldcereal.

"S2-L2A-B02": "B02",
"S2-L2A-B03": "B03",
"S2-L2A-B04": "B04",
"S2-L2A-B05": "B05",
"S2-L2A-B06": "B06",
"S2-L2A-B07": "B07",
"S2-L2A-B08": "B08",
"S2-L2A-B8A": "B8A",
"S2-L2A-B11": "B11",
"S2-L2A-B12": "B12",
"S1-SIGMA0-VH": "VH",
"S1-SIGMA0-VV": "VV",
"COP-DEM": "DEM",
"AGERA5-TMEAN": "temperature-mean",
"AGERA5-PRECIP": "precipitation-flux",
}
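
For reference, a minimal sketch of how a mapping like this behaves when applied with a fallback, using only a subset of the entries. `rename_bands` is a hypothetical helper for illustration; the point is that names without an entry pass through unchanged.

```python
# Subset of the mapping from the diff, for illustration only.
GFMAP_BAND_MAPPING = {
    "S2-L2A-B02": "B02",
    "S1-SIGMA0-VH": "VH",
    "COP-DEM": "DEM",
}


def rename_bands(band_names: list) -> list:
    """Map GFMAP band names to Presto names, keeping unmapped names as-is."""
    return [GFMAP_BAND_MAPPING.get(name, name) for name in band_names]


renamed = rename_bands(["S2-L2A-B02", "S1-SIGMA0-VH", "COP-DEM", "custom-band"])
```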

def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> str:
import urllib.request
import zipfile
from pathlib import Path

# Downloads the wheel file
modelfile, _ = urllib.request.urlretrieve(
wheel_url, filename=Path.cwd() / Path(wheel_url).name
)
with zipfile.ZipFile(modelfile, "r") as zip_ref:
zip_ref.extractall(destination_dir)
return destination_dir
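
The method above relies on a wheel being an ordinary zip archive, so extracting it and adding the folder to `sys.path` is enough to import pure-Python code from it. A self-contained sketch of that mechanism, using a stand-in archive built locally instead of a download; the file and module names are made up for the example.

```python
import importlib
import sys
import tempfile
import zipfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)

    # Build a stand-in "wheel": a zip archive holding one importable module.
    wheel_path = tmp_path / "demo_pkg-0.1.0-py3-none-any.whl"
    with zipfile.ZipFile(wheel_path, "w") as archive:
        archive.writestr("demo_wheel_module.py", "ANSWER = 42\n")

    # Unpack it like the method above does, then expose the folder to the interpreter.
    destination = tmp_path / "deps"
    with zipfile.ZipFile(wheel_path, "r") as archive:
        archive.extractall(destination)
    sys.path.append(str(destination))

    demo = importlib.import_module("demo_wheel_module")
    answer = demo.ANSWER

    # Clean up the path entry before the temporary folder disappears.
    sys.path.remove(str(destination))
```

This only covers pure-Python wheels; packages with compiled extensions or install-time metadata may need a proper installer.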

def output_labels(self) -> list:
"""Returns the output labels from this UDF, which is the output labels
of the presto embeddings"""
return [f"presto_ft_{i}" for i in range(128)]

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
import sys
from pathlib import Path

if self.epsg is None:
raise ValueError(
"EPSG code is required for Presto feature extraction, but was "
"not correctly initialized."
)
presto_model_url = self._parameters.get(
"presto_model_url", self.PRESTO_MODEL_URL
)
presto_wheel_url = self._parameters.get("presto_wheel_url", self.PRESO_WHL_URL)

# The below is required to avoid flipping of the result
# when running on OpenEO backend!
inarr = inarr.transpose("bands", "t", "x", "y")

# Change the band names
new_band_names = [
self.GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands
]
inarr = inarr.assign_coords(bands=new_band_names)

# Handle NaN values in Presto compatible way
inarr = inarr.fillna(65535)

# Unzip the dependencies on the backend
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

# Debug, print the dependency directory
self.logger.info(f"Dependency directory: {list(Path(deps_dir).iterdir())}")

from presto.inference import get_presto_features

self.logger.info("Extracting presto features")
features = get_presto_features(inarr, presto_model_url, self.epsg)
return features

def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
# Disable S1 rescaling (decompression) by default
if parameters.get("rescale_s1", None) is None:
parameters.update({"rescale_s1": False})
return super()._execute(cube, parameters)
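
The default handling in `_execute` treats a missing `rescale_s1` and an explicit `None` the same way, which differs from `dict.setdefault` (that would keep an explicit `None`). A standalone sketch of the same logic; `with_s1_rescaling_disabled` is a hypothetical helper and, unlike the method above, it works on a copy rather than updating the dictionary in place.

```python
def with_s1_rescaling_disabled(parameters: dict) -> dict:
    """Return a copy of `parameters` where a missing or None `rescale_s1` becomes False."""
    updated = dict(parameters)
    if updated.get("rescale_s1") is None:
        updated["rescale_s1"] = False
    return updated


defaults_applied = with_s1_rescaling_disabled({})
explicit_kept = with_s1_rescaling_disabled({"rescale_s1": True})
none_replaced = with_s1_rescaling_disabled({"rescale_s1": None})
```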