diff --git a/poetry.lock b/poetry.lock index 2af37418d..5617a7da9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -904,6 +904,28 @@ dev = ["pre-commit"] doc = ["markdown-include", "mkdocs", "mkdocs-material", "mkdocstrings"] test = ["black", "django-stubs", "flake8", "isort", "mypy (==0.931)", "psycopg2-binary", "pytest", "pytest-asyncio", "pytest-cov", "pytest-django"] +[[package]] +name = "django-s3-file-field" +version = "1.0.1" +description = "A Django library for uploading files directly to AWS S3 or MinIO Storage from HTTP clients." +optional = false +python-versions = ">=3.8" +files = [ + {file = "django_s3_file_field-1.0.1-py3-none-any.whl", hash = "sha256:4a607367b8bacf4d4c76cbea9f172e7c98e15da1efb839a4cdcb48a5b41036c4"}, + {file = "django_s3_file_field-1.0.1.tar.gz", hash = "sha256:093675afbf29ba874fc02ac6a70c826919fab952689dbcb32848ae5c318b4ca7"}, +] + +[package.dependencies] +django = ">=3.2" +django-minio-storage = {version = ">=0.5", optional = true, markers = "extra == \"minio\""} +djangorestframework = "*" +minio = {version = ">=7", optional = true, markers = "extra == \"minio\""} + +[package.extras] +minio = ["django-minio-storage (>=0.5)", "minio (>=7)"] +pytest = ["pytest"] +s3 = ["boto3", "django-storages[s3] (>=1.14)"] + [[package]] name = "django-storages" version = "1.14.2" @@ -964,6 +986,20 @@ files = [ django = "*" typing-extensions = "*" +[[package]] +name = "djangorestframework" +version = "3.15.2" +description = "Web APIs for Django, made easy." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "djangorestframework-3.15.2-py3-none-any.whl", hash = "sha256:2b8871b062ba1aefc2de01f773875441a961fefbf79f5eed1e32b2f096944b20"}, + {file = "djangorestframework-3.15.2.tar.gz", hash = "sha256:36fe88cd2d6c6bec23dca9804bab2ba5517a8bb9d8f47ebc68981b56840107ad"}, +] + +[package.dependencies] +django = ">=4.2" + [[package]] name = "filelock" version = "3.13.1" @@ -1005,41 +1041,6 @@ files = [ [package.dependencies] python-dateutil = ">=2.7" -[[package]] -name = "fsspec" -version = "2024.3.1" -description = "File-system specification" -optional = false -python-versions = ">=3.8" -files = [ - {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, - {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, -] - -[package.extras] -abfs = ["adlfs"] -adl = ["adlfs"] -arrow = ["pyarrow (>=1)"] -dask = ["dask", "distributed"] -devel = ["pytest", "pytest-cov"] -dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] -fuse = ["fusepy"] -gcs = ["gcsfs"] -git = ["pygit2"] -github = ["requests"] -gs = ["gcsfs"] -gui = ["panel"] -hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] -libarchive = ["libarchive-c"] -oci = ["ocifs"] -s3 = ["s3fs"] -sftp = ["paramiko"] -smb = ["smbprotocol"] -ssh = ["paramiko"] -tqdm = ["tqdm"] - [[package]] name = "geojson" version = "3.1.0" @@ -1286,23 +1287,6 @@ files = [ {file = "iso3166-2.1.1.tar.gz", hash = "sha256:fcd551b8dda66b44e9f9e6d6bbbee3a1145a22447c0a556e5d0fb1ad1e491719"}, ] -[[package]] -name = "jinja2" -version = "3.1.3" -description = "A very fast and expressive template engine." 
-optional = false -python-versions = ">=3.7" -files = [ - {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, - {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, -] - -[package.dependencies] -MarkupSafe = ">=2.0" - -[package.extras] -i18n = ["Babel (>=2.7)"] - [[package]] name = "jmespath" version = "1.0.1" @@ -1537,24 +1521,6 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] -[[package]] -name = "networkx" -version = "3.2.1" -description = "Python package for creating and manipulating graphs and networks" -optional = false -python-versions = ">=3.9" -files = [ - {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, - {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, -] - -[package.extras] -default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] -developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] -doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] -test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] - [[package]] name = "nodeenv" version = "1.8.0" @@ -2321,6 +2287,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2328,8 +2295,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2346,6 +2321,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2353,6 +2329,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, diff --git a/pyproject.toml b/pyproject.toml index e2b1f39a7..7f405cd1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ apache-airflow-client = "^2.9.0" beautifulsoup4 = "^4.12.3" django-allauth = {extras = ["socialaccount"], version = "^0.63.2"} django-login-required-middleware = "^0.9.0" +django-s3-file-field = {extras = ["minio"], version = "^1.0.1"} [tool.poetry.group.dev.dependencies] django-stubs = "^4.2.7" diff --git a/rdwatch/core/admin.py b/rdwatch/core/admin.py index 8c09ec737..74d22637f 100644 --- a/rdwatch/core/admin.py +++ b/rdwatch/core/admin.py @@ -2,6 +2,7 @@ from rdwatch.core.models import ( ModelRun, + ModelRunUpload, Performer, Region, SatelliteFetching, @@ -123,3 +124,15 @@ class 
SiteObservationAdmin(admin.ModelAdmin):
     )
     list_filter = ('timestamp',)
     raw_id_fields = ('siteeval', 'label', 'constellation', 'spectrum')
+
+
+@admin.register(ModelRunUpload)
+class ModelRunUploadAdmin(admin.ModelAdmin):
+    list_display = (
+        'id',
+        'title',
+        'performer',
+        'region',
+        'zipfile',
+        'task_id',
+    )
diff --git a/rdwatch/core/migrations/0033_modelrunupload.py b/rdwatch/core/migrations/0033_modelrunupload.py
new file mode 100644
index 000000000..888e20f78
--- /dev/null
+++ b/rdwatch/core/migrations/0033_modelrunupload.py
@@ -0,0 +1,61 @@
+# Generated by Django 5.0.8 on 2024-08-14 21:17
+
+import uuid
+
+import s3_file_field.fields
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    """Create the table for the ModelRunUpload model."""
+
+    dependencies = [
+        ('core', '0032_siteevaluation_point_siteobservation_point_and_more'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='ModelRunUpload',
+            fields=[
+                (
+                    'id',
+                    models.UUIDField(
+                        default=uuid.uuid4,
+                        editable=False,
+                        primary_key=True,
+                        serialize=False,
+                    ),
+                ),
+                ('title', models.CharField(max_length=1000)),
+                (
+                    'private',
+                    models.BooleanField(
+                        default=False,
+                        help_text='Whether this model run should be private',
+                    ),
+                ),
+                (
+                    'region',
+                    models.CharField(
+                        blank=True,
+                        help_text='Override for the region this model run belongs to',
+                        max_length=1000,
+                    ),
+                ),
+                (
+                    'performer',
+                    models.CharField(
+                        blank=True,
+                        help_text='Shortcode override for the team that produced this evaluation',
+                        max_length=1000,
+                    ),
+                ),
+                ('zipfile', s3_file_field.fields.S3FileField()),
+                (
+                    'task_id',
+                    models.CharField(help_text='Celery task ID', max_length=256),
+                ),
+            ],
+        ),
+    ]
diff --git a/rdwatch/core/models/__init__.py b/rdwatch/core/models/__init__.py
index ee79c7eb3..cd4fef68f 100644
--- a/rdwatch/core/models/__init__.py
+++ b/rdwatch/core/models/__init__.py
@@ -1,6 +1,7 @@
 from .
import lookups
 from .annotation_exports import AnnotationExport
 from .model_run import ModelRun
+from .model_run_upload import ModelRunUpload
 from .performer import Performer
 from .region import Region
 from .satellite_fetching import SatelliteFetching
@@ -12,6 +13,7 @@
     'AnnotationExport',
     'lookups',
     'ModelRun',
+    'ModelRunUpload',
     'Performer',
     'Region',
     'SiteEvaluation',
diff --git a/rdwatch/core/models/model_run_upload.py b/rdwatch/core/models/model_run_upload.py
new file mode 100644
index 000000000..dccbca3d4
--- /dev/null
+++ b/rdwatch/core/models/model_run_upload.py
@@ -0,0 +1,47 @@
+from uuid import uuid4
+
+from s3_file_field import S3FileField
+
+from django.db import models
+from django.db.models.signals import pre_delete
+from django.dispatch import receiver
+
+
+class ModelRunUpload(models.Model):
+    """An uploaded model-run zipfile together with its metadata.
+
+    ``region`` and ``performer`` may be left blank; ``task_id`` holds the ID
+    of the Celery task associated with the upload.
+    """
+
+    id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
+
+    title = models.CharField(max_length=1000)
+    private = models.BooleanField(
+        default=False, help_text='Whether this model run should be private'
+    )
+    region = models.CharField(
+        max_length=1000,
+        blank=True,
+        help_text='Override for the region this model run belongs to',
+    )
+    performer = models.CharField(
+        max_length=1000,
+        blank=True,
+        help_text='Shortcode override for the team that produced this evaluation',
+    )
+    zipfile = S3FileField()
+
+    task_id = models.CharField(max_length=256, help_text='Celery task ID')
+
+    def __str__(self) -> str:
+        # Identify the upload by primary key and title instead of returning an
+        # empty string, which renders as a blank entry wherever it is shown.
+        return f'ModelRunUpload {self.id}: {self.title}'
+
+
+@receiver(pre_delete, sender=ModelRunUpload)
+def delete_zipfile(sender, instance, **kwargs):
+    """Delete the stored zipfile just before its ModelRunUpload is deleted."""
+    if instance.zipfile:
+        instance.zipfile.delete(save=False)
diff --git a/rdwatch/core/tasks/__init__.py b/rdwatch/core/tasks/__init__.py
index b8dee0862..3de9fa981 100644
--- a/rdwatch/core/tasks/__init__.py
+++ b/rdwatch/core/tasks/__init__.py
@@ -6,8 +6,8 @@ import zipfile
 
 from collections.abc import Iterable
 from datetime import datetime, timedelta
-from typing import Literal
-from uuid import uuid4
+from typing import
Literal, TypeVar
+from uuid import UUID, uuid4
 
 import cv2
 import numpy as np
@@ -18,7 +18,7 @@
 from django_celery_results.models import TaskResult
 from more_itertools import ichunked
 from PIL import Image
-from pydantic import UUID4
+from pydantic import UUID4, BaseModel, ValidationError
 from pyproj import Transformer
 from segment_anything import SamPredictor, sam_model_registry
 
@@ -34,12 +34,19 @@
 from rdwatch.core.models import (
     AnnotationExport,
     ModelRun,
+    ModelRunUpload,
+    Performer,
     SatelliteFetching,
     SiteEvaluation,
     SiteImage,
     SiteObservation,
 )
 from rdwatch.core.models.lookups import Constellation
+from rdwatch.core.models.region import get_or_create_region
+
+from rdwatch.core.views.site_evaluation import get_site_model_feature_JSON
+from rdwatch.core.schemas.region_model import RegionModel
+from rdwatch.core.schemas.site_model import SiteModel
 from rdwatch.core.utils.images import (
     fetch_boundbox_image,
     get_max_bbox,
@@ -53,7 +60,6 @@
 from rdwatch.core.utils.worldview_processed.raster_tile import (
     get_worldview_processed_visual_bbox,
 )
-from rdwatch.core.views.site_evaluation import get_site_model_feature_JSON
 
 logger = logging.getLogger(__name__)
 # lowest time to use if time is null for observations
@@ -719,3 +725,108 @@ def download_sam_model_if_not_exists(**kwargs):
             return f'Error downloading file: {e}'
     else:
         return f'File already exists at {file_path}'
+
+
+ModelT = TypeVar('ModelT', bound=BaseModel)
+
+
+def parse_model_json(ModelClass: type[ModelT], data: str | bytes) -> ModelT | None:
+    """Parse ``data`` as ``ModelClass``, returning None if validation fails."""
+    try:
+        return ModelClass.parse_raw(data)
+    except ValidationError:
+        return None
+
+
+def process_model_run_upload(model_run_upload: ModelRunUpload):
+    """Create a ModelRun and its site evaluations from an uploaded zipfile.
+
+    Every ``.geojson`` member of the zipfile is tried as a region model and
+    then as a site model. The upload's ``region`` and ``performer`` overrides
+    take precedence over the values found in the parsed models.
+
+    Raises:
+        ValueError: if the zipfile lacks site models or region models.
+    """
+    # parse out the site and region models from the uploaded zipfile
+    site_models: list[SiteModel] = []
+    region_models: list[RegionModel] = []
+
+    with model_run_upload.zipfile.open('rb') as fp, zipfile.ZipFile(fp, 'r') as zipfp:
+        for filename in zipfp.namelist():
+            if not filename.endswith('.geojson'):
+                continue
+
+            contents = zipfp.read(filename)
+
+            model = parse_model_json(RegionModel, contents)
+            if model:
+                region_models.append(model)
+                continue
+
+            model = parse_model_json(SiteModel, contents)
+            if model:
+                site_models.append(model)
+                continue
+
+            # Neither schema matched: skip the file, but leave a trace of it.
+            logger.warning('Ignoring invalid geojson in upload: %s', filename)
+
+    # NOTE(review): this rejects an upload that lacks *either* kind of model,
+    # while the message reads as if both must be missing -- confirm the intent.
+    if len(site_models) == 0 or len(region_models) == 0:
+        raise ValueError('Did not receive any site or region models')
+
+    all_region_ids: set[str] = set()
+    all_originators: set[str] = set()
+    for model in site_models:
+        all_region_ids.add(model.site_feature.properties.region_id)
+        all_originators.add(model.site_feature.properties.originator)
+    for model in region_models:
+        all_region_ids.add(model.region_feature.properties.region_id)
+        all_originators.add(model.region_feature.properties.originator)
+
+    # TODO handle len(all_region_ids) > 1, len(all_originators) > 1
+    if len(all_region_ids) == 0:
+        raise ValueError('No regions')
+    if len(all_originators) == 0:
+        raise ValueError('No originators')
+
+    region_id = model_run_upload.region or next(iter(all_region_ids))
+    performer_shortcode = model_run_upload.performer or next(iter(all_originators))
+
+    with transaction.atomic():
+        # create a new ModelRun
+        region, _ = get_or_create_region(region_id)
+        # Look the performer up by short code alone; team_name only applies
+        # when a new row has to be created.
+        performer, _ = Performer.objects.get_or_create(
+            short_code=performer_shortcode.upper(),
+            defaults={'team_name': performer_shortcode},
+        )
+        model_run = ModelRun.objects.create(
+            title=model_run_upload.title,
+            performer=performer,
+            region=region,
+            # TODO toggle for {'ground_truth': True/False}?
+            # TODO is this where the private toggle can be used?
+            parameters={},
+            # TODO handle expiration_time, evaluation, evaluation_run, proposal?
+        )
+
+        for site_model in site_models:
+            SiteEvaluation.bulk_create_from_site_model(site_model, model_run)
+        for region_model in region_models:
+            SiteEvaluation.bulk_create_from_region_model(region_model, model_run)
+
+
+@shared_task
+def process_model_run_upload_task(upload_id: UUID):
+    """Process the ModelRunUpload with the given ID, then delete it."""
+    model_run_upload = ModelRunUpload.objects.get(pk=upload_id)
+
+    try:
+        process_model_run_upload(model_run_upload)
+    finally:
+        # The upload is removed whether or not processing succeeded.
+        model_run_upload.delete()
diff --git a/rdwatch/core/urls.py b/rdwatch/core/urls.py
index 5448f2381..2213c261e 100644
--- a/rdwatch/core/urls.py
+++ b/rdwatch/core/urls.py
@@ -1,10 +1,11 @@
-from django.urls import path
+from django.urls import include, path
 
 from . import views
 from .api import api
 
 urlpatterns = [
     path('', api.urls),
+    path('s3-upload/', include('s3_file_field.urls')),
     path('satellite-image/timestamps', views.satelliteimage_time_list),
     path('satellite-image/all-timestamps', views.all_satellite_timestamps),
     path(
diff --git a/rdwatch/core/views/model_run.py b/rdwatch/core/views/model_run.py
index 1f3d120da..dbc50d9fe 100644
--- a/rdwatch/core/views/model_run.py
+++ b/rdwatch/core/views/model_run.py
@@ -3,6 +3,7 @@
 
 from celery.result import AsyncResult
 from ninja import Field, FilterSchema, Query, Schema
+from ninja.errors import ValidationError
 from ninja.pagination import PageNumberPagination, RouterPaginated, paginate
 from ninja.schema import validator
 from ninja.security import APIKeyHeader
@@ -10,7 +11,9 @@
 
 from django.conf import settings
 from django.contrib.postgres.aggregates import JSONBAgg
+from django.core import signing
 from django.core.cache import cache
+from django.core.files.storage import default_storage
 from django.db import transaction
 from django.db.models import (
     Avg,
@@ -36,6 +39,7 @@
 from rdwatch.core.models import (
     AnnotationExport,
     ModelRun,
+    ModelRunUpload,
     Performer,
     Region,
     SatelliteFetching,
@@ -48,6 +52,7 @@
     cancel_generate_images_task,
     download_annotations,
     generate_site_images_for_evaluation_run,
+
process_model_run_upload_task, ) from rdwatch.core.views.performer import PerformerSchema from rdwatch.core.views.site_observation import GenerateImagesSchema @@ -163,6 +168,14 @@ class ModelRunListSchema(Schema): mode: Literal['batch', 'incremental'] | None = None +class ModelRunUploadSchema(Schema): + title: str + region: str | None = None + performer: str | None = None + zipfileKey: str + private: bool = False + + def get_queryset(): # Subquery to count unique SiteEvaluations # with proposal='PROPOSAL' for each ModelRun @@ -637,3 +650,36 @@ def get_downloaded_annotations(request: HttpRequest, id: UUID4, task_id: str): 'Content-Disposition' ] = f'attachment; filename="{annotation_export.name}.zip"' return response + + +@router.post('/start_upload_processing') +def start_model_run_upload_processing( + request: HttpRequest, upload_data: ModelRunUploadSchema +): + zipfile_upload = signing.loads(upload_data.zipfileKey) + + if not upload_data.title.strip(): + raise ValidationError('Invalid model run title') + + with transaction.atomic(): + upload = ModelRunUpload.objects.create( + title=upload_data.title, + region=upload_data.region, + performer=upload_data.performer, + zipfile=zipfile_upload['object_key'], + private=upload_data.private, + ) + if not default_storage.exists(upload.zipfile.name): + raise ValidationError('Invalid file name provided') + + task = process_model_run_upload_task.delay(upload.id) + upload.task_id = task.id + upload.save() + + return task.id + + +@router.get('/upload_status/{task_id}') +def model_run_upload_status(request: HttpRequest, task_id: str): + result = AsyncResult(task_id) + return result.status diff --git a/rdwatch/settings.py b/rdwatch/settings.py index a0d104716..9ca2f210e 100644 --- a/rdwatch/settings.py +++ b/rdwatch/settings.py @@ -56,6 +56,7 @@ def INSTALLED_APPS(self): 'allauth.socialaccount', 'allauth.socialaccount.providers.gitlab', 'rdwatch.core.apps.RDWatchConfig', + 's3_file_field', ] if 'RDWATCH_POSTGRESQL_SCORING_URI' in 
os.environ: base_applications.append('rdwatch.scoring.apps.ScoringConfig') diff --git a/vue/package-lock.json b/vue/package-lock.json index 6bc40535f..396206e72 100644 --- a/vue/package-lock.json +++ b/vue/package-lock.json @@ -13,6 +13,7 @@ "@turf/turf": "7.0.0-alpha.113", "@types/mapbox__mapbox-gl-draw": "^1.4.6", "canvas-capture": "^2.1.1", + "django-s3-file-field": "^1.0.1", "lodash": "^4.17.21", "maplibre-gl": "^2.1.9", "npyjs": "^0.6.0", @@ -3908,6 +3909,11 @@ "node": ">=8" } }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" + }, "node_modules/autoprefixer": { "version": "10.4.19", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.19.tgz", @@ -3959,6 +3965,16 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/axios": { + "version": "1.7.3", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.3.tgz", + "integrity": "sha512-Ar7ND9pU99eJ9GpoGQKhKf58GpUOgnzuaB7ueNQ5BMi0p+LZ5oaEnfF999fAArcTIBwXTCHAmGcHOZJaWPq9Nw==", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.0", + "proxy-from-env": "^1.1.0" + } + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -4226,6 +4242,17 @@ "integrity": "sha512-jeC1axXpnb0/2nn/Y1LPuLdgXBLH7aDcHu4KEKfqw3CUhX7ZpfBSlPKyqXE6btIgEzfWtrX3/tyBCaCvXvMkOw==", "dev": true }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/commander": { "version": "10.0.1", "resolved": 
"https://registry.npmjs.org/commander/-/commander-10.0.1.tgz", @@ -4535,6 +4562,14 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/dir-glob": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", @@ -4547,6 +4582,14 @@ "node": ">=8" } }, + "node_modules/django-s3-file-field": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/django-s3-file-field/-/django-s3-file-field-1.0.1.tgz", + "integrity": "sha512-va0BHToHhEKB8+hguj7vQdbdVNYcmlatPC/m/eszzOZNrNeKe3rPAJphqmTjXaYmBD4RedqlEM/IgsSQFU5Sjw==", + "dependencies": { + "axios": "^1.6.8" + } + }, "node_modules/doctrine": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", @@ -5567,6 +5610,25 @@ "integrity": "sha512-X8cqMLLie7KsNUDSdzeN8FYK9rEt4Dt67OsG/DNGnYTSDBG4uFAJFBnUeiV+zCVAvwFy56IjM9sH51jVaEhNxw==", "dev": true }, + "node_modules/follow-redirects": { + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, "node_modules/for-each": { "version": "0.3.3", "resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.3.tgz", @@ -5591,6 +5653,19 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/form-data": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", + 
"integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, "node_modules/fraction.js": { "version": "4.3.7", "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", @@ -6908,6 +6983,25 @@ "node": ">=10" } }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -7527,6 +7621,11 @@ "resolved": "https://registry.npmjs.org/protocol-buffers-schema/-/protocol-buffers-schema-3.6.0.tgz", "integrity": "sha512-TdDRD+/QNdrCGCE7v8340QyuXd4kIWIgapsE2+n/SaGiSSbomYl4TjHlvIoCWRpE7wFt02EpB35VVA2ImcBVqw==" }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==" + }, "node_modules/punycode": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", diff --git a/vue/package.json b/vue/package.json index 510049f65..5ed30b70b 100644 --- a/vue/package.json +++ b/vue/package.json @@ -24,6 +24,7 @@ "@turf/turf": "7.0.0-alpha.113", 
"@types/mapbox__mapbox-gl-draw": "^1.4.6", "canvas-capture": "^2.1.1", + "django-s3-file-field": "^1.0.1", "lodash": "^4.17.21", "maplibre-gl": "^2.1.9", "npyjs": "^0.6.0", diff --git a/vue/src/client/services/ApiService.ts b/vue/src/client/services/ApiService.ts index e027a6df6..e0dad9c19 100644 --- a/vue/src/client/services/ApiService.ts +++ b/vue/src/client/services/ApiService.ts @@ -185,6 +185,14 @@ export interface RegionUpload { } +export interface ModelRunUpload { + title: string; + region: string | null | undefined; + performer: string | null | undefined; + zipfileKey: string; + private: boolean; +} + type ApiPrefix = '/api' | '/api/scoring'; export class ApiService { @@ -594,6 +602,21 @@ export class ApiService { } + public static postModelRunUpload(data: ModelRunUpload): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: `${this.getApiPrefix()}/model-runs/start_upload_processing`, + body: { ...data }, + }); + } + + public static getModelRunUploadTaskStatus(taskId: string): CancelablePromise { + return __request(OpenAPI, { + method: "GET", + url: `${this.getApiPrefix()}/model-runs/upload_status/${taskId}`, + }); + } + public static getSiteImageEmbeddingStatus(id: number, uuid: string): CancelablePromise<{state: string, status: string}> { return __request(OpenAPI, { method: "GET", diff --git a/vue/src/components/SideBar.vue b/vue/src/components/SideBar.vue index 14d7f0500..e2ca398ad 100644 --- a/vue/src/components/SideBar.vue +++ b/vue/src/components/SideBar.vue @@ -15,6 +15,7 @@ import type { Ref } from "vue"; import { changeTime } from "../interactions/timeStepper"; import { useRoute } from "vue-router"; import ModeSelector from './ModeSelector.vue'; +import UploadModelRun from './UploadModelRun.vue'; const timemin = ref(Math.floor(new Date(0).valueOf() / 1000)); @@ -183,6 +184,14 @@ const satelliteLoadingColor = computed(() => { {{ state.appVersion }} + + + +import { reactive, ref } from "vue"; +import S3FileFieldClient, { + 
S3FileFieldResultState,
+} from "django-s3-file-field";
+import { ApiService } from "../client";
+
+const s3ffClient = new S3FileFieldClient({
+  baseUrl: `${ApiService.getApiPrefix()}/s3-upload/`,
+  apiConfig: {
+    // django csrf token handling
+    xsrfCookieName: "csrftoken",
+    xsrfHeaderName: "X-CSRFToken",
+  },
+});
+
+const modelRun = reactive({
+  title: "",
+  region: "",
+  performer: "",
+  private: false,
+});
+
+const uploadDialog = ref(false);
+const validForm = ref(false);
+// The file input may hold a single File or a list of them; upload() handles both.
+const uploadFile = ref<File | File[] | undefined>();
+const uploadLoading = ref(false);
+const uploadError = ref<unknown>();
+// Infinity means indeterminate
+const uploadProgress = ref(0);
+
+const QUERY_TASK_POLL_DELAY = 1000; // msec
+// from celery.states.READY_STATES
+const CELERY_READY_STATES = ["FAILURE", "SUCCESS", "REVOKED"];
+
+/** Open the upload dialog with any previous file selection and error cleared. */
+function openUploadDialog() {
+  uploadDialog.value = true;
+  uploadFile.value = undefined;
+  uploadError.value = undefined;
+}
+
+/**
+ * Poll the server until the task reaches one of CELERY_READY_STATES.
+ *
+ * Rejects when a status request fails, so the caller can surface the error
+ * instead of waiting indefinitely.
+ *
+ * @param taskId - ID of the task to watch.
+ * @returns The ready state that ended the polling.
+ */
+async function untilTaskReady(taskId: string): Promise<string> {
+  for (;;) {
+    const state = await ApiService.getModelRunUploadTaskStatus(taskId);
+    if (CELERY_READY_STATES.includes(state)) return state;
+
+    await new Promise((resolve) => setTimeout(resolve, QUERY_TASK_POLL_DELAY));
+  }
+}
+
+/**
+ * Upload the selected file, start server-side processing, and wait for the
+ * processing task to finish. Any failure is stored in `uploadError`.
+ */
+async function upload() {
+  try {
+    uploadLoading.value = true;
+    uploadError.value = undefined;
+    uploadProgress.value = Infinity;
+
+    const file = Array.isArray(uploadFile.value)
+      ? uploadFile.value[0]
+      : uploadFile.value;
+    if (!file) return;
+
+    const uploadResult = await s3ffClient.uploadFile(
+      file,
+      "core.ModelRunUpload.zipfile"
+    );
+    if (uploadResult.state !== S3FileFieldResultState.Successful) {
+      const status = ["was aborted", "", "errored"][uploadResult.state];
+      throw new Error(`File upload ${status}`);
+    }
+
+    const taskId = await ApiService.postModelRunUpload({
+      title: modelRun.title.trim(),
+      region: modelRun.region.trim(),
+      performer: modelRun.performer.trim(),
+      zipfileKey: uploadResult.value,
+      private: modelRun.private,
+    });
+
+    const taskState = await untilTaskReady(taskId);
+    if (taskState !== "SUCCESS") {
+      throw new Error(`Encountered non-success task state: ${taskState}`);
+    }
+
+    uploadDialog.value = false;
+  } catch (err) {
+    uploadError.value = err;
+  } finally {
+    uploadLoading.value = false;
+  }
+}
+
+/** Form rule: the title must contain something other than whitespace. */
+function validateTitle(title: string) {
+  if (title.trim().length === 0) return "Must add a title";
+  return true;
+}
+
+/** Form rule: a file must be selected; accepts a single File, a list, or nothing. */
+function validateFile(files: File | File[] | null | undefined) {
+  const count = Array.isArray(files) ? files.length : files ? 1 : 0;
+  if (count === 0) return "Must provide a model run file";
+  return true;
+}
+
+
+