Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SenBench-AQ-NO2-S5P and SenBench-AQ-O3-S5P datasets from SentinelBench #2607

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,308 +3,311 @@

"""TorchGeo datasets."""

from .advance import ADVANCE
from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
from .agrifieldnet import AgriFieldNet
from .airphen import Airphen
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet, BigEarthNetV2
from .biomassters import BioMassters
from .bright import BRIGHTDFC2025
from .cabuar import CaBuAr
from .caffe import CaFFe
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chabud import ChaBuD
from .chesapeake import (
Chesapeake,
ChesapeakeCVPR,
ChesapeakeDC,
ChesapeakeDE,
ChesapeakeMD,
ChesapeakeNY,
ChesapeakePA,
ChesapeakeVA,
ChesapeakeWV,
)
from .cloud_cover import CloudCoverDetection
from .cms_mangrove_canopy import CMSGlobalMangroveCanopy
from .cowc import COWC, COWCCounting, COWCDetection
from .cropharvest import CropHarvest
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCyclone
from .deepglobelandcover import DeepGlobeLandCover
from .dfc2022 import DFC2022
from .digital_typhoon import DigitalTyphoon
from .dl4gam import DL4GAMAlps
from .eddmaps import EDDMapS
from .enmap import EnMAP
from .enviroatlas import EnviroAtlas
from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError
from .esri2020 import Esri2020
from .etci2021 import ETCI2021
from .eudem import EUDEM
from .eurocrops import EuroCrops
from .eurosat import EuroSAT, EuroSAT100, EuroSATSpatial
from .fair1m import FAIR1M
from .fire_risk import FireRisk
from .forestdamage import ForestDamage
from .ftw import FieldsOfTheWorld
from .gbif import GBIF
from .geo import (
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
UnionDataset,
VectorDataset,
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .hyspecnet import HySpecNet11k
from .idtrees import IDTReeS
from .inaturalist import INaturalist
from .inria import InriaAerialImageLabeling
from .iobench import IOBench
from .l7irish import L7Irish
from .l8biome import L8Biome
from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
from .landsat import (
Landsat,
Landsat1,
Landsat2,
Landsat3,
Landsat4MSS,
Landsat4TM,
Landsat5MSS,
Landsat5TM,
Landsat7,
Landsat8,
Landsat9,
)
from .levircd import LEVIRCD, LEVIRCDBase, LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
from .mdas import MDAS
from .millionaid import MillionAID
from .mmearth import MMEarth
from .mmflood import MMFlood
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
from .nccm import NCCM
from .nlcd import NLCD
from .openbuildings import OpenBuildings
from .oscd import OSCD
from .pastis import PASTIS
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .prisma import PRISMA
from .quakeset import QuakeSet
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rwanda_field_boundary import RwandaFieldBoundary
from .satlas import SatlasPretrain
from .seasonet import SeasoNet
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel1, Sentinel2
from .skippd import SKIPPD
from .skyscript import SkyScript
from .so2sat import So2Sat
from .south_africa_crop_type import SouthAfricaCropType
from .south_america_soybean import SouthAmericaSoybean
from .spacenet import (
SpaceNet,
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet6,
SpaceNet7,
SpaceNet8,
)
from .splits import (
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)
from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
from .ssl4eo_benchmark import SSL4EOLBenchmark
from .sustainbench_crop_yield import SustainBenchCropYield
from .treesatai import TreeSatAI
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
BoundingBox,
concat_samples,
merge_samples,
stack_samples,
unbind_samples,
)
from .vaihingen import Vaihingen2D
from .vhr10 import VHR10
from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
from .xview import XView2
from .zuericrop import ZueriCrop
from .senbench_airquality_s5p import SenBenchAQNO2S5P, SenBenchAQO3S5P

Check failure on line 155 in torchgeo/datasets/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

torchgeo/datasets/__init__.py:6:1: I001 Import block is un-sorted or un-formatted

Check failure on line 155 in torchgeo/datasets/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/datasets/__init__.py:155:56: F401 `.senbench_airquality_s5p.SenBenchAQO3S5P` imported but unused; consider removing, adding to `__all__`, or using a redundant alias

# Public API of the package. Entries are kept in the order ruff's RUF022 rule
# expects: ALL-CAPS names first, then CamelCase names, then lowercase helpers.
# Every entry must correspond to a name imported above, otherwise
# ``from torchgeo.datasets import *`` fails.
__all__ = (
    'ADVANCE',
    'BRIGHTDFC2025',
    'CDL',
    'COWC',
    'DFC2022',
    'ETCI2021',
    'EUDEM',
    'FAIR1M',
    'GBIF',
    'GID15',
    'LEVIRCD',
    'MDAS',
    'NAIP',
    'NCCM',
    'NLCD',
    'OSCD',
    'PASTIS',
    'PRISMA',
    'RESISC45',
    'SEN12MS',
    'SKIPPD',
    'SSL4EO',
    'SSL4EOL',
    'SSL4EOS12',
    'VHR10',
    'AbovegroundLiveWoodyBiomassDensity',
    'AgriFieldNet',
    'Airphen',
    'AsterGDEM',
    'BeninSmallHolderCashews',
    'BigEarthNet',
    'BigEarthNetV2',
    'BioMassters',
    'BoundingBox',
    'CMSGlobalMangroveCanopy',
    'COWCCounting',
    'COWCDetection',
    'CV4AKenyaCropType',
    'CaBuAr',
    'CaFFe',
    'CanadianBuildingFootprints',
    'ChaBuD',
    'Chesapeake',
    'ChesapeakeCVPR',
    'ChesapeakeDC',
    'ChesapeakeDE',
    'ChesapeakeMD',
    'ChesapeakeNY',
    'ChesapeakePA',
    'ChesapeakeVA',
    'ChesapeakeWV',
    'CloudCoverDetection',
    'CropHarvest',
    'DL4GAMAlps',
    'DatasetNotFoundError',
    'DeepGlobeLandCover',
    'DependencyNotFoundError',
    'DigitalTyphoon',
    'EDDMapS',
    'EnMAP',
    'EnviroAtlas',
    'Esri2020',
    'EuroCrops',
    'EuroSAT',
    'EuroSAT100',
    'EuroSATSpatial',
    'FieldsOfTheWorld',
    'FireRisk',
    'ForestDamage',
    'GeoDataset',
    'GeoNRW',
    'GlobBiomass',
    'HySpecNet11k',
    'IDTReeS',
    'INaturalist',
    'IOBench',
    'InriaAerialImageLabeling',
    'IntersectionDataset',
    'L7Irish',
    'L8Biome',
    'LEVIRCDBase',
    'LEVIRCDPlus',
    'LandCoverAI',
    'LandCoverAI100',
    'LandCoverAIBase',
    'LandCoverAIGeo',
    'Landsat',
    'Landsat1',
    'Landsat2',
    'Landsat3',
    'Landsat4MSS',
    'Landsat4TM',
    'Landsat5MSS',
    'Landsat5TM',
    'Landsat7',
    'Landsat8',
    'Landsat9',
    'LoveDA',
    'MMEarth',
    'MMFlood',
    'MapInWild',
    'MillionAID',
    'NASAMarineDebris',
    'NonGeoClassificationDataset',
    'NonGeoDataset',
    'OpenBuildings',
    'PatternNet',
    'Potsdam2D',
    'QuakeSet',
    'RGBBandsMissingError',
    'RasterDataset',
    'ReforesTree',
    'RwandaFieldBoundary',
    'SSL4EOLBenchmark',
    'SatlasPretrain',
    'SeasoNet',
    'SeasonalContrastS2',
    'SenBenchAQNO2S5P',
    'SenBenchAQO3S5P',
    'Sentinel',
    'Sentinel1',
    'Sentinel2',
    'SkyScript',
    'So2Sat',
    'SouthAfricaCropType',
    'SouthAmericaSoybean',
    'SpaceNet',
    'SpaceNet1',
    'SpaceNet2',
    'SpaceNet3',
    'SpaceNet4',
    'SpaceNet5',
    'SpaceNet6',
    'SpaceNet7',
    'SpaceNet8',
    'SustainBenchCropYield',
    'TreeSatAI',
    'TropicalCyclone',
    'UCMerced',
    'USAVars',
    'UnionDataset',
    'Vaihingen2D',
    'VectorDataset',
    'WesternUSALiveFuelMoisture',
    'XView2',
    'ZueriCrop',
    'concat_samples',
    'merge_samples',
    'random_bbox_assignment',
    'random_bbox_splitting',
    'random_grid_cell_assignment',
    'roi_split',
    'stack_samples',
    'time_series_split',
    'unbind_samples',
)

Check failure on line 313 in torchgeo/datasets/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF022)

torchgeo/datasets/__init__.py:157:11: RUF022 `__all__` is not sorted
202 changes: 202 additions & 0 deletions torchgeo/datasets/senbench_airquality_s5p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import kornia as K

Check failure on line 1 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D100)

torchgeo/datasets/senbench_airquality_s5p.py:1:1: D100 Missing docstring in public module

Check failure on line 1 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/datasets/senbench_airquality_s5p.py:1:18: F401 `kornia` imported but unused
import torch
from torchgeo.datasets.geo import NonGeoDataset
import os
from collections.abc import Callable, Sequence

Check failure on line 5 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/datasets/senbench_airquality_s5p.py:5:39: F401 `collections.abc.Sequence` imported but unused
from torch import Tensor
import numpy as np
import rasterio
import cv2
from pyproj import Transformer
from datetime import date
from typing import TypeAlias, ClassVar

Check failure on line 12 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/datasets/senbench_airquality_s5p.py:12:31: F401 `typing.ClassVar` imported but unused
import pathlib

Check failure on line 13 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/datasets/senbench_airquality_s5p.py:13:8: F401 `pathlib` imported but unused

import logging

Check failure on line 15 in torchgeo/datasets/senbench_airquality_s5p.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

torchgeo/datasets/senbench_airquality_s5p.py:1:1: I001 Import block is un-sorted or un-formatted

logging.getLogger("rasterio").setLevel(logging.ERROR)
Path: TypeAlias = str | os.PathLike[str]

class SenBenchAirQualityS5P(NonGeoDataset):
    """Parent class for the SenBench-AQ-NO2-S5P and SenBench-AQ-O3-S5P datasets.

    The SenBench-AQ-NO2-S5P and SenBench-AQ-O3-S5P datasets are level-3 datasets
    from the SentinelBench benchmark. They contain Sentinel-5P NO2/O3 images and
    EEA NO2/O3 maps for the air pollutant regression task. Both a static mode
    (1 image per location, annual mean) and a time series mode (~4 images per
    location, seasonal mean) are supported; the former is used in the original
    benchmark.

    Dataset features:

    * task: dense regression
    * # samples: 1480/493/494 (train/val/test, static mode)
    * image resolution: 56x56 (GSD 1km)
    * label resolution: 56x56 (GSD 1km)
    * mode: annual (static) or seasonal (time series)
    * modality: NO2 or O3

    Dataset format:

    * images: 1 band Sentinel-5P NO2/O3 images (GeoTIFF)
    * labels: EEA NO2/O3 maps (GeoTIFF)

    If you use this dataset in your research, please cite the following paper:

    * To be released soon
    """

    url = 'https://huggingface.co/datasets/wangyi111/SentinelBench/resolve/main/l3_airquality_s5p/airquality_s5p.zip'
    splits = ('train', 'val', 'test')
    # File (inside ``<root>/<modality>``) listing the patch ids of each split.
    split_fnames: ClassVar[dict[str, str]] = {
        'train': 'train.csv',
        'val': 'val.csv',
        'test': 'test.csv',
    }
    # Sub-directory (inside ``<root>/<modality>``) holding the Sentinel-5P
    # images for each supported mode.
    mode_dirs: ClassVar[dict[str, str]] = {
        'annual': 's5p_annual',
        'seasonal': 's5p_seasonal',
    }

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        modality: str = 'no2',
        mode: str = 'annual',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize a new dataset instance.

        Args:
            root: root directory where the dataset can be found
            split: one of 'train', 'val', or 'test'
            modality: name of the pollutant sub-directory under ``root``
                ('no2' or 'o3')
            mode: 'annual' (one image per location) or 'seasonal'
                (a list of four images per location)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            download: stored on the instance; this class does not itself
                perform any download

        Raises:
            AssertionError: If ``split`` is not one of :attr:`splits`.
            ValueError: If ``mode`` is not a key of :attr:`mode_dirs`.
        """
        assert split in self.splits
        # Fail with a clear message instead of an unbound-variable error
        # further down when an unknown mode is requested.
        if mode not in self.mode_dirs:
            raise ValueError(
                f"Unsupported mode '{mode}', expected one of {sorted(self.mode_dirs)}"
            )

        self.root = root
        self.transforms = transforms
        self.download = download
        self.modality = modality
        self.mode = mode

        self.img_dir = os.path.join(root, modality, self.mode_dirs[mode])
        # Labels are always read from the annual label directory, whichever
        # mode is selected for the images.
        self.label_dir = os.path.join(root, modality, 'label_annual')

        self.split_csv = os.path.join(root, modality, self.split_fnames[split])
        with open(self.split_csv) as f:
            # One patch id per line; blank lines are skipped because an empty
            # id would resolve to the image directory itself.
            self.pids = [line.strip() for line in f if line.strip()]

        self.reference_date = date(1970, 1, 1)
        self.patch_area = (4 * 1) ** 2  # patch size 4 pix, GSD 1 km

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.pids)

    def __getitem__(self, index: int) -> dict[str, object]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            a dict with keys 'image', 'groundtruth' and 'meta'. In 'annual'
            mode 'image' and 'meta' are the first loaded image and its
            metadata; in 'seasonal' mode they are the full lists. The dict is
            passed through ``transforms`` when one was given.
        """
        images, meta_infos = self._load_image(index)
        label = self._load_target(index)

        if self.mode == 'annual':
            sample = {'image': images[0], 'groundtruth': label, 'meta': meta_infos[0]}
        else:
            sample = {'image': images, 'groundtruth': label, 'meta': meta_infos}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _load_image(self, index: int) -> tuple[list[Tensor], list[dict[str, Tensor]]]:
        """Load all Sentinel-5P images of one patch together with their metadata.

        Args:
            index: index of the patch in :attr:`pids`

        Returns:
            a list of image tensors of shape (1, 56, 56) and a parallel list
            of metadata dicts with keys 'lon', 'lat', 'delta-t' and 'area-p'

        Raises:
            FileNotFoundError: If the patch directory contains no files.
        """
        pid = self.pids[index]
        s5p_path = os.path.join(self.img_dir, pid)

        # os.listdir returns entries in arbitrary order. File names start with
        # the acquisition date (year first), so sorting them makes the order
        # deterministic and chronological.
        img_fnames = sorted(os.listdir(s5p_path))
        if not img_fnames:
            raise FileNotFoundError(f'No Sentinel-5P images found in {s5p_path}')

        imgs = []
        meta_infos = []
        for img_fname in img_fnames:
            img_path = os.path.join(s5p_path, img_fname)
            with rasterio.open(img_path) as src:
                img = src.read(1)
                img[np.isnan(img)] = 0
                img = cv2.resize(img, (56, 56), interpolation=cv2.INTER_CUBIC)
                img = torch.from_numpy(img).float().unsqueeze(0)

                # Coordinates of the central pixel, converted to lon/lat when
                # the raster is not already in EPSG:4326.
                cx, cy = src.xy(src.height // 2, src.width // 2)
                if src.crs.to_string() != 'EPSG:4326':
                    crs_transformer = Transformer.from_crs(
                        src.crs, 'epsg:4326', always_xy=True
                    )
                    lon, lat = crs_transformer.transform(cx, cy)
                else:
                    lon, lat = cx, cy

            # The first 10 characters of the file name encode the date as
            # year (0-3), month (5-6) and day (8-9).
            date_str = img_fname.split('_')[0][:10]
            date_obj = date(int(date_str[:4]), int(date_str[5:7]), int(date_str[8:10]))
            delta = (date_obj - self.reference_date).days
            meta_info = {
                'lon': torch.tensor(lon),
                'lat': torch.tensor(lat),
                'delta-t': torch.tensor(delta),  # days since 1970-01-01
                'area-p': torch.tensor(self.patch_area),  # ViT patch area in km^2
            }

            imgs.append(img)
            meta_infos.append(meta_info)

        if self.mode == 'seasonal':
            # Pad to 4 images by repeating the last image and its metadata.
            while len(imgs) < 4:
                imgs.append(img)
                meta_infos.append(meta_info)

        return imgs, meta_infos

    def _load_target(self, index: int) -> Tensor:
        """Load the label map of one patch.

        Args:
            index: index of the patch in :attr:`pids`

        Returns:
            a float32 tensor of shape (56, 56) in which values below -1e10 or
            above 1e10 (e.g. -inf) are replaced by NaN
        """
        pid = self.pids[index]
        label_path = os.path.join(self.label_dir, pid + '.tif')

        with rasterio.open(label_path) as src:
            label = src.read(1)

        # NOTE(review): the original author's comment gives the label value
        # range as 0-650 - confirm against the data.
        label = cv2.resize(label, (56, 56), interpolation=cv2.INTER_NEAREST)
        # Cast before masking so NaN can be stored even when the raster is
        # integer-typed.
        label = label.astype('float32')
        label[label < -1e10] = np.nan
        label[label > 1e10] = np.nan

        return torch.from_numpy(label)


class SenBenchAQNO2S5P(SenBenchAirQualityS5P):
    """SenBench-AQ-NO2-S5P dataset.

    Thin wrapper around :class:`SenBenchAirQualityS5P` with the modality
    fixed to 'no2'.
    """

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        mode: str = 'annual',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize a new NO2 dataset instance.

        Args:
            root: root directory where the dataset can be found
            split: one of 'train', 'val', or 'test'
            mode: 'annual' or 'seasonal'
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            download: forwarded unchanged to the parent constructor
        """
        super().__init__(
            root=root,
            split=split,
            modality='no2',
            mode=mode,
            transforms=transforms,
            download=download,
        )


class SenBenchAQO3S5P(SenBenchAirQualityS5P):
    """SenBench-AQ-O3-S5P dataset.

    Thin wrapper around :class:`SenBenchAirQualityS5P` with the modality
    fixed to 'o3'.
    """

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        mode: str = 'annual',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize a new O3 dataset instance.

        Args:
            root: root directory where the dataset can be found
            split: one of 'train', 'val', or 'test'
            mode: 'annual' or 'seasonal'
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            download: forwarded unchanged to the parent constructor
        """
        super().__init__(
            root=root,
            split=split,
            modality='o3',
            mode=mode,
            transforms=transforms,
            download=download,
        )
Loading