Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SenBench-LC100Cls-S3 dataset and SenBench-LC100Seg-S3 dataset from SentinelBench #2605

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions tests/data/senbench_cloud_s3/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import numpy as np
import rasterio
from rasterio.transform import Affine
from datetime import datetime, timedelta

Check failure on line 5 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/data/senbench_cloud_s3/data.py:1:1: I001 Import block is un-sorted or un-formatted

def generate_fake_dataset(root_dir='data', num_train=2, num_val=1, num_test=1):

Check failure on line 7 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN201)

tests/data/senbench_cloud_s3/data.py:7:5: ANN201 Missing return type annotation for public function `generate_fake_dataset`

Check failure on line 7 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:7:27: ANN001 Missing type annotation for function argument `root_dir`

Check failure on line 7 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:7:44: ANN001 Missing type annotation for function argument `num_train`

Check failure on line 7 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:7:57: ANN001 Missing type annotation for function argument `num_val`

Check failure on line 7 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:7:68: ANN001 Missing type annotation for function argument `num_test`
"""Generates a fake dataset for testing the SenBenchCloudS3 dataset class.

Args:
root_dir (str): Root directory where the fake dataset will be created.
num_train (int): Number of training samples.
num_val (int): Number of validation samples.
num_test (int): Number of test samples.
"""
# Create directories
s3_olci_dir = os.path.join(root_dir, 's3_olci')
cloud_multi_dir = os.path.join(root_dir, 'cloud_multi')
cloud_binary_dir = os.path.join(root_dir, 'cloud_binary')
os.makedirs(s3_olci_dir, exist_ok=True)
os.makedirs(cloud_multi_dir, exist_ok=True)
os.makedirs(cloud_binary_dir, exist_ok=True)

# Generate filename components
start_date = datetime(2020, 1, 1)

def generate_samples(num_samples, offset=0):

Check failure on line 27 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN202)

tests/data/senbench_cloud_s3/data.py:27:9: ANN202 Missing return type annotation for private function `generate_samples`

Check failure on line 27 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:27:26: ANN001 Missing type annotation for function argument `num_samples`

Check failure on line 27 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/data/senbench_cloud_s3/data.py:27:39: ANN001 Missing type annotation for function argument `offset`
samples = []
for i in range(num_samples):
current_date = start_date + timedelta(days=offset+i)
date_str = current_date.strftime("%Y%m%d")
fname = f"S3_{i+offset:04d}____{date_str}_000000.tif"
samples.append(fname)
return samples

# Generate sample filenames for each split
# Create sample lists with sequential dates
train_samples = generate_samples(num_train, 0)
val_samples = generate_samples(num_val, num_train)
test_samples = generate_samples(num_test, num_train + num_val)

# Write CSV files
def write_csv(split, samples):

Check failure on line 43 in tests/data/senbench_cloud_s3/data.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN202)

tests/data/senbench_cloud_s3/data.py:43:9: ANN202 Missing return type annotation for private function `write_csv`
csv_path = os.path.join(root_dir, f'{split}.csv')
with open(csv_path, 'w') as f:
f.write('\n'.join(samples))

write_csv('train', train_samples)
write_csv('val', val_samples)
write_csv('test', test_samples)

# Generate all samples (train + val + test)
all_samples = train_samples + val_samples + test_samples

# Rasterio parameters
height = width = 256
transform = Affine.identity() # Identity transform for simplicity
crs = 'EPSG:4326' # WGS84 coordinate system

for sample in all_samples:
# Generate fake Sentinel-3 OLCI image (21 bands)
img_path = os.path.join(s3_olci_dir, sample)
with rasterio.open(
img_path,
'w',
driver='GTiff',
height=height,
width=width,
count=21,
dtype=np.float32,
transform=transform,
crs=crs
) as dst:
for band in range(1, 22):
data = np.random.rand(height, width).astype(np.float32)
dst.write(data, band)

# Generate multi-class cloud mask (values 0-5)
multi_path = os.path.join(cloud_multi_dir, sample)
with rasterio.open(
multi_path,
'w',
driver='GTiff',
height=height,
width=width,
count=1,
dtype=np.uint8,
transform=transform,
crs=crs
) as dst:
data = np.random.randint(0, 6, (height, width), dtype=np.uint8)
dst.write(data, 1)

# Generate binary cloud mask (values 0-2)
binary_path = os.path.join(cloud_binary_dir, sample)
with rasterio.open(
binary_path,
'w',
driver='GTiff',
height=height,
width=width,
count=1,
dtype=np.uint8,
transform=transform,
crs=crs
) as dst:
data = np.random.randint(0, 3, (height, width), dtype=np.uint8)
dst.write(data, 1)

if __name__ == '__main__':
generate_fake_dataset(root_dir='./senbench_cloud_s3')
print("Fake dataset generated successfully.")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/senbench_cloud_s3/senbench_cloud_s3/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
S3_0003____20200104_000000.tif
2 changes: 2 additions & 0 deletions tests/data/senbench_cloud_s3/senbench_cloud_s3/train.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
S3_0000____20200101_000000.tif
S3_0001____20200102_000000.tif
1 change: 1 addition & 0 deletions tests/data/senbench_cloud_s3/senbench_cloud_s3/val.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
S3_0002____20200103_000000.tif
10 changes: 10 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@
from .xview import XView2
from .zuericrop import ZueriCrop

from .senbench_cloud_s3 import SenBenchCloudS3
from .senbench_lc100cls_s3 import SenBenchLC100ClsS3
from .senbench_lc100seg_s3 import SenBenchLC100SegS3
from .sentinelbench import SentinelBench


__all__ = (
'ADVANCE',
'BRIGHTDFC2025',
Expand Down Expand Up @@ -307,4 +313,8 @@
'stack_samples',
'time_series_split',
'unbind_samples',
'SenBenchCloudS3',
'SenBenchLC100ClsS3',
'SenBenchLC100SegS3',
'SentinelBench',
)
181 changes: 181 additions & 0 deletions torchgeo/datasets/senbench_cloud_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import torch
from torchgeo.datasets.geo import NonGeoDataset
import os
from collections.abc import Callable, Sequence
from torch import Tensor
import numpy as np
import rasterio
from pyproj import Transformer
from datetime import date
from typing import TypeAlias, ClassVar
from .utils import Path, download_and_extract_archive, extract_archive

import logging

logging.getLogger("rasterio").setLevel(logging.ERROR)

class SenBenchCloudS3(NonGeoDataset):
"""SenBench-Cloud-S3 dataset.

The SenBench-Cloud-S3 dataset is a level-1 dataset from the SentinelBench benchmark.
It contains Sentinel-3 OLCI images, multi-class cloud masks, and binary cloud masks for the cloud segmentation task.

Dataset features:

* task: semantic segmentation
* # samples: 1197/399/399 (train/val/test)
* image resolution: 256x256
* # classes: 5 (multi-class) / 2 (binary)

Dataset format:

* images: 21-band Sentinel-3 OLCI images (GeoTIFF)
* labels: multi-class cloud masks (GeoTIFF)
* binary_labels: binary cloud masks (GeoTIFF)

If you use this dataset in your research, please cite the following paper:

* To be released soon


"""
url = 'https://huggingface.co/datasets/wangyi111/SentinelBench/resolve/main/l3_biomass_s3/biomass_s3olci.zip'

splits = ('train', 'val', 'test')

split_filenames = {
'train': 'train.csv',
'val': 'val.csv',
'test': 'test.csv',
}

all_band_names = (
'Oa01_radiance', 'Oa02_radiance', 'Oa03_radiance', 'Oa04_radiance', 'Oa05_radiance', 'Oa06_radiance', 'Oa07_radiance',
'Oa08_radiance', 'Oa09_radiance', 'Oa10_radiance', 'Oa11_radiance', 'Oa12_radiance', 'Oa13_radiance', 'Oa14_radiance',
'Oa15_radiance', 'Oa16_radiance', 'Oa17_radiance', 'Oa18_radiance', 'Oa19_radiance', 'Oa20_radiance', 'Oa21_radiance',
)

all_band_scale = (
0.0139465,0.0133873,0.0121481,0.0115198,0.0100953,0.0123538,0.00879161,
0.00876539,0.0095103,0.00773378,0.00675523,0.0071996,0.00749684,0.0086512,
0.00526779,0.00530267,0.00493004,0.00549962,0.00502847,0.00326378,0.00324118)

rgb_bands = ('Oa08_radiance', 'Oa06_radiance', 'Oa04_radiance')

Cls_index_binary = {
'invalid': 0, # --> 255 should be ignored during training
'clear': 1, # --> 0
'cloud': 2, # --> 1
}

Cls_index_multi = {
'invalid': 0, # --> 255 should be ignored during training
'clear': 1, # --> 0
'cloud-sure': 2, # --> 1
'cloud-ambiguous': 3, # --> 2
'cloud shadow': 4, # --> 3
'snow and ice': 5, # --> 4
}

def __init__(
self,
root: Path = 'data',
split: str = 'train',
bands: Sequence[str] = all_band_names,
mode = 'multi',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
) -> None:

self.root = root
self.transforms = transforms
self.download = download

assert split in ['train', 'val', 'test']

self.bands = bands
self.band_indices = [(self.all_band_names.index(b)+1) for b in bands if b in self.all_band_names]

self.mode = mode
self.img_dir = os.path.join(self.root, 's3_olci')
self.label_dir = os.path.join(self.root, 'cloud_'+mode)

self.split_csv = os.path.join(self.root, self.split_filenames[split])
self.fnames = []
with open(self.split_csv, 'r') as f:
lines = f.readlines()
for line in lines:
fname = line.strip()
self.fnames.append(fname)

self.reference_date = date(1970, 1, 1)
self.patch_area = (8*300/1000)**2 # patchsize 8 pix, gsd 300m

def __len__(self):
return len(self.fnames)

def __getitem__(self, index):

images, meta_infos = self._load_image(index)
label = self._load_target(index)
sample = {'image': images, 'mask': label, 'meta': meta_infos}

if self.transforms is not None:
sample = self.transforms(sample)

return sample


def _load_image(self, index):

fname = self.fnames[index]
s3_path = os.path.join(self.img_dir, fname)

with rasterio.open(s3_path) as src:
img = src.read(self.band_indices)
img[np.isnan(img)] = 0
chs = []
for b in range(21):
ch = img[b]*self.all_band_scale[b]
chs.append(ch)
img = np.stack(chs)
img = torch.from_numpy(img).float()

# get lon, lat
cx,cy = src.xy(src.height // 2, src.width // 2)
if src.crs.to_string() != 'EPSG:4326':
# convert to lon, lat
crs_transformer = Transformer.from_crs(src.crs, 'epsg:4326', always_xy=True)
lon, lat = crs_transformer.transform(cx,cy)
else:
lon, lat = cx, cy
# get time
img_fname = os.path.basename(s3_path)
date_str = img_fname.split('____')[1][:8]
date_obj = date(int(date_str[:4]), int(date_str[4:6]), int(date_str[6:8]))
delta = (date_obj - self.reference_date).days
# this is what CopernicusFM requires
#meta_info = np.array([lon, lat, delta, self.patch_area]).astype(np.float32)
#meta_info = torch.from_numpy(meta_info)
# this is more general
meta_info = {
'lon': torch.tensor(lon),
'lat': torch.tensor(lat),
'delta-t': torch.tensor(delta), # days since 1970-01-01
'area-p': torch.tensor(self.patch_area), # ViT patch area in km^2
}

return img, meta_info

def _load_target(self, index):

fname = self.fnames[index]
label_path = os.path.join(self.label_dir, fname)

with rasterio.open(label_path) as src:
label = src.read(1)
label[label==0] = 256
label = label - 1
labels = torch.from_numpy(label).long()

return labels
Loading
Loading