
Implementation of other sensors #134

Open
wants to merge 12 commits into main
1 change: 1 addition & 0 deletions graph_weather/__init__.py
@@ -1,4 +1,5 @@
"""Main import for the complete models"""

from .data.nnjai_wrapp import SensorDataset, collate_fn
Review comment (Member):
I don't think we should actually import this here, since we have nnja as an optional dependency. To import this dataset, like in the tests, we should just do `from graph_weather.data.nnja_ai import SensorDataset`, for example. I also renamed this on the main branch, so could you update nnjai_wrapp.py to be named nnja_ai? I think that is more consistent naming. (A short sketch of the suggested import pattern follows this diff.)

from .models.analysis import GraphWeatherAssimilator
from .models.forecast import GraphWeatherForecaster
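A minimal sketch (not part of this diff) of the import style the reviewer suggests, assuming the module has been renamed to nnja_ai on main; the package root then stays importable even when the optional nnja dependency is missing:

    # Hypothetical consumer code: import from the submodule directly instead of
    # relying on a re-export in graph_weather/__init__.py.
    from graph_weather.data.nnja_ai import SensorDataset, collate_fn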
2 changes: 2 additions & 0 deletions graph_weather/data/__init__.py
@@ -1 +1,3 @@
"""Dataloaders and data processing utilities"""

from .nnjai_wrapp import SensorDataset, collate_fn
106 changes: 106 additions & 0 deletions graph_weather/data/nnjai_wrapp.py
@@ -0,0 +1,106 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this file as its not used now?

A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS

The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
"""

import numpy as np
import torch
from torch.utils.data import Dataset

try:
from nnja import DataCatalog
except ImportError:
print(
"NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`"
)


class SensorDataset(Dataset):

Review comment (Member):
Suggested change:
- class SensorDataset(Dataset):
+ class NNJADataset(Dataset):
We might (and probably will, as I am looking into adding other, non-NNJA sensors) have other sensors that won't fit in this dataset. I think naming this just NNJADataset is more descriptive and makes it easier to parse where this data is coming from.

    """A custom PyTorch Dataset for handling various sensor data."""

def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"):
"""Initialize the dataset loader for various sensors.

Args:
dataset_name: Name of the dataset to load.
time: Specific timestamp to filter the data.
primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON).
additional_variables: List of additional variables to include in metadata.
sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS)
"""
self.dataset_name = dataset_name
self.time = time
self.primary_descriptors = primary_descriptors
self.additional_variables = additional_variables
self.sensor_type = sensor_type # New argument for selecting sensor type

# Load data catalog and dataset
self.catalog = DataCatalog(skip_manifest=True)
self.dataset = self.catalog[self.dataset_name]
self.dataset.load_manifest()

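        # Column sets selected per sensor type (matching the branches below):
        # AMSU -> TMBR_00001..TMBR_00015, ATMS -> TMBR_00001..TMBR_00022,
        # MHS  -> TMBR_00001..TMBR_00005, IASI -> SCRA_00001..SCRA_00616,
        # CrIS -> SRAD01_00001..SRAD01_00431.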
if self.sensor_type == "AMSU":
self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)]
)
elif self.sensor_type == "ATMS":
self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)]
)
elif self.sensor_type == "MHS":
self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)]
)
elif self.sensor_type == "IASI":
self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)]
)
elif self.sensor_type == "CrIS":
self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)]
)
else:
raise ValueError(f"Unsupported sensor type: {self.sensor_type}")

self.dataframe = self.dataset.load_dataset(engine="pandas")

for col in primary_descriptors:
if col not in self.dataframe.columns:
raise ValueError(f"The dataset must include a '{col}' column.")

self.metadata_columns = [
col for col in self.dataframe.columns if col not in self.primary_descriptors
]

def __len__(self):
"""Return the total number of samples in the dataset."""
return len(self.dataframe)

def __getitem__(self, index):
"""Return the observation and metadata for a given index."""
row = self.dataframe.iloc[index]
time = row["OBS_TIMESTAMP"].timestamp()
latitude = row["LAT"]
longitude = row["LON"]
metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)

return {
"timestamp": torch.tensor(time, dtype=torch.float32),
"latitude": torch.tensor(latitude, dtype=torch.float32),
"longitude": torch.tensor(longitude, dtype=torch.float32),
"metadata": torch.from_numpy(metadata),
}


def collate_fn(batch):
"""Custom collate function to handle batching of dictionary data.

Args:
batch: List of dictionaries from __getitem__

Returns:
Single dictionary with batched tensors
"""
return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
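As a usage sketch (not part of this diff), the dataset and collate_fn plug into a standard torch DataLoader; the dataset name and variables below are taken from the tests, and everything else is an assumption for illustration:

    from datetime import datetime

    from torch.utils.data import DataLoader

    from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn

    dataset = SensorDataset(
        dataset_name="amsu-1bamua-NC021023",
        time=datetime(2021, 1, 1, 0, 0),
        primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables=["TMBR_00001", "TMBR_00002"],
        sensor_type="AMSU",
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

    batch = next(iter(loader))
    # batch["timestamp"], batch["latitude"], batch["longitude"]: shape (4,)
    # batch["metadata"]: shape (4, number_of_metadata_columns)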
166 changes: 166 additions & 0 deletions tests/test_nnjai.py
@@ -0,0 +1,166 @@
"""
Tests for the nnjai_wrapp module in the graph_weather package.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Tests for the nnjai_wrapp module in the graph_weather package.
Tests for the nnja_ai module in the graph_weather package.


This file contains unit tests for AMSUDataset and collate_fn functions.
"""

from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
import torch

from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn

@pytest.fixture
def mock_datacatalog():
"""
Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
This mock provides a mock dataset with predefined columns and values.
"""
with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
# Mock dataset structure
mock_df = MagicMock()
mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]

# Define a mock row
class MockRow:
def __getitem__(self, key):
data = {
"OBS_TIMESTAMP": datetime.now(),
"LAT": 45.0,
"LON": -120.0,
"TMBR_00001": 250.0,
"TMBR_00002": 260.0,
}
return data.get(key, None)

# Configure mock dataset
mock_row = MockRow()
mock_df.iloc = MagicMock()
mock_df.iloc.__getitem__.return_value = mock_row
mock_df.__len__.return_value = 100

mock_dataset = MagicMock()
mock_dataset.load_dataset.return_value = mock_df
mock_dataset.sel.return_value = mock_dataset
mock_dataset.load_manifest = MagicMock()

mock.return_value.__getitem__.return_value = mock_dataset
yield mock


def test_sensor_dataset(mock_datacatalog):
"""
Test the SensorDataset class to ensure proper data loading and tensor structure for different sensors.
"""
# Test for AMSU dataset
dataset_name = "amsu-1bamua-NC021023"
time = datetime(2021, 1, 1, 0, 0) # Using datetime object instead of string
primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
additional_variables = ["TMBR_00001", "TMBR_00002"]
    dataset = SensorDataset(
        dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"
    )
Review comment (Member):
Do these tests still pass? I think they need to be updated for the new names. Ideally, there should be tests for all the different sensor types as well in NNJADataset. You should be able to do that with pytest.mark.parametrize, I think (see the sketch after this test function).

# Test dataset length
assert len(dataset) > 0, "Dataset should not be empty."

item = dataset[0]
expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."

# Validate tensor properties
assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."
assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."
assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."
assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
assert item["metadata"].shape == (len(additional_variables),), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."

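A hedged sketch of the parametrized test suggested above (the test name is hypothetical), reusing the mocked DataCatalog fixture and the dataset names and variables from test_sensor_datasets below; it assumes the current SensorDataset name, which would become NNJADataset after the rename:

    @pytest.mark.parametrize(
        "dataset_name,sensor_type,additional_variables",
        [
            ("amsu-1bamua-NC021023", "AMSU", ["TMBR_00001", "TMBR_00002"]),
            ("atms-atms-NC021203", "ATMS", ["TMBR_00001", "TMBR_00002"]),
            ("mhs-1bmhs-NC021027", "MHS", ["TMBR_00001", "TMBR_00002"]),
            ("iasi-mtiasi-NC021241", "IASI", ["IASIL1CB"]),
            ("cris-crisf4-NC021206", "CrIS", ["SRAD01_00001", "SRAD01_00002"]),
        ],
    )
    def test_sensor_dataset_parametrized(
        mock_datacatalog, dataset_name, sensor_type, additional_variables
    ):
        """One test case per sensor type, sharing the mocked DataCatalog."""
        dataset = SensorDataset(
            dataset_name,
            datetime(2021, 1, 1, 0, 0),
            ["OBS_TIMESTAMP", "LAT", "LON"],
            additional_variables,
            sensor_type=sensor_type,
        )
        item = dataset[0]
        assert set(item.keys()) == {"timestamp", "latitude", "longitude", "metadata"}
        assert item["metadata"].dtype == torch.float32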

def test_collate_function():
"""
Test the collate_fn function to ensure proper batching of dataset items.
"""
# Mock a batch of items
batch_size = 4
metadata_size = 2
mock_batch = [
{
"timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
"latitude": torch.tensor(45.0, dtype=torch.float32),
"longitude": torch.tensor(-120.0, dtype=torch.float32),
"metadata": torch.randn(metadata_size, dtype=torch.float32),
}
for _ in range(batch_size)
]

# Collate the batch
batched = collate_fn(mock_batch)

# Validate batched shapes and types
assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch."
assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."


def test_sensor_datasets(mock_datacatalog):
"""
Test various sensor datasets (AMSU-A, ATMS, MHS, IASI, CrIS) to ensure they load properly
and print the relevant information.
"""
# Define datasets and associated parameters for different sensors
    sensors = [
        {"name": "amsu-1bamua-NC021023", "time": datetime(2021, 1, 1, 0, 0),
         "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
         "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "AMSU"},
        {"name": "atms-atms-NC021203", "time": datetime(2021, 1, 1, 0, 0),
         "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
         "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "ATMS"},
        {"name": "mhs-1bmhs-NC021027", "time": datetime(2021, 1, 1, 0, 0),
         "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
         "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "MHS"},
        {"name": "iasi-mtiasi-NC021241", "time": datetime(2021, 1, 1, 0, 0),
         "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
         "additional_variables": ["IASIL1CB"], "sensor_type": "IASI"},
        {"name": "cris-crisf4-NC021206", "time": datetime(2021, 1, 1, 0, 0),
         "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
         "additional_variables": ["SRAD01_00001", "SRAD01_00002"], "sensor_type": "CrIS"},
    ]

# Loop through each sensor and load the dataset
for sensor in sensors:
print(f"\nTesting sensor: {sensor['name']}")

# Create the dataset instance
        dataset = SensorDataset(
            sensor['name'], sensor['time'], sensor['primary_descriptors'],
            sensor['additional_variables'], sensor_type=sensor['sensor_type'],
        )

# Print dataset length
print(f"Dataset length for {sensor['name']}: {len(dataset)}")

# Retrieve and print the first item
item = dataset[0]
print(f"First item from {sensor['name']}:")
print(item)

# Ensure the dataset item structure is correct
expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
assert set(item.keys()) == expected_keys, f"Dataset item keys for {sensor['name']} are not as expected."

# Validate tensor properties
assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor['name']}."
assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor['name']}."
assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor['name']}."

assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor['name']}."
assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor['name']}."
assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor['name']}."

assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor['name']}."
assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor['name']}."
assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor['name']}."

assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor['name']}."
assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor['name']}"

print(f"Metadata for {sensor['name']}: {item['metadata']}\n")