Implementation of other sensors #134
base: main
Changes from 8 commits
graph_weather/__init__.py
@@ -1,4 +1,5 @@
 """Main import for the complete models"""

+from .data.nnjai_wrapp import SensorDataset, collate_fn
 from .models.analysis import GraphWeatherAssimilator
 from .models.forecast import GraphWeatherForecaster
graph_weather/data/__init__.py
@@ -1 +1,3 @@
 """Dataloaders and data processing utilities"""
+
+from .nnjai_wrapp import SensorDataset, collate_fn
graph_weather/data/nnjai_wrapp.py (new file)

Review comment: Could you remove this file as it's not used now?

@@ -0,0 +1,106 @@
"""
A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS.

The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
"""
import numpy as np
import torch
from torch.utils.data import Dataset

try:
    from nnja import DataCatalog
except ImportError:
    print(
        "NNJA-AI library not installed. Please install with "
        "`pip install git+https://github.com/brightbandtech/nnja-ai.git`"
    )

class SensorDataset(Dataset):

Review comment (with a suggested change): We might (probably will, as I am looking into adding other, non-NNJA sensors) have other sensors that won't fit in this dataset. I think naming this just…
"""A custom PyTorch Dataset for handling various sensor data.""" | ||||||
|
||||||
def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): | ||||||
"""Initialize the dataset loader for various sensors. | ||||||
|
||||||
Args: | ||||||
dataset_name: Name of the dataset to load. | ||||||
time: Specific timestamp to filter the data. | ||||||
primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). | ||||||
additional_variables: List of additional variables to include in metadata. | ||||||
sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS) | ||||||
""" | ||||||
self.dataset_name = dataset_name | ||||||
self.time = time | ||||||
self.primary_descriptors = primary_descriptors | ||||||
self.additional_variables = additional_variables | ||||||
self.sensor_type = sensor_type # New argument for selecting sensor type | ||||||
|
||||||

        # Load data catalog and dataset
        self.catalog = DataCatalog(skip_manifest=True)
        self.dataset = self.catalog[self.dataset_name]
        self.dataset.load_manifest()

        # Select the channel variables appropriate for each sensor type
        if self.sensor_type == "AMSU":
            self.dataset = self.dataset.sel(
                time=self.time,
                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)],
            )
        elif self.sensor_type == "ATMS":
            self.dataset = self.dataset.sel(
                time=self.time,
                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)],
            )
        elif self.sensor_type == "MHS":
            self.dataset = self.dataset.sel(
                time=self.time,
                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)],
            )
        elif self.sensor_type == "IASI":
            self.dataset = self.dataset.sel(
                time=self.time,
                variables=self.primary_descriptors + [f"SCRA_{i:05d}" for i in range(1, 617)],
            )
        elif self.sensor_type == "CrIS":
            self.dataset = self.dataset.sel(
                time=self.time,
                variables=self.primary_descriptors + [f"SRAD01_{i:05d}" for i in range(1, 432)],
            )
        else:
            raise ValueError(f"Unsupported sensor type: {self.sensor_type}")

        self.dataframe = self.dataset.load_dataset(engine="pandas")

        for col in primary_descriptors:
            if col not in self.dataframe.columns:
                raise ValueError(f"The dataset must include a '{col}' column.")

        self.metadata_columns = [
            col for col in self.dataframe.columns if col not in self.primary_descriptors
        ]

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return the observation and metadata for a given index."""
        row = self.dataframe.iloc[index]
        time = row["OBS_TIMESTAMP"].timestamp()
        latitude = row["LAT"]
        longitude = row["LON"]
        metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)

        return {
            "timestamp": torch.tensor(time, dtype=torch.float32),
            "latitude": torch.tensor(latitude, dtype=torch.float32),
            "longitude": torch.tensor(longitude, dtype=torch.float32),
            "metadata": torch.from_numpy(metadata),
        }


def collate_fn(batch):
    """Custom collate function to handle batching of dictionary data.

    Args:
        batch: List of dictionaries from __getitem__.

    Returns:
        Single dictionary with batched tensors.
    """
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
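
For orientation, a minimal usage sketch of this dataset with a PyTorch DataLoader (not part of the diff; it assumes nnja-ai is installed and reuses the `amsu-1bamua-NC021023` catalog name that appears in the tests below):

from datetime import datetime

from torch.utils.data import DataLoader

from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn

# Catalog name and variables mirror the test file below; adjust them to
# whatever entries are actually available in your NNJA-AI data catalog.
dataset = SensorDataset(
    dataset_name="amsu-1bamua-NC021023",
    time=datetime(2021, 1, 1, 0, 0),
    primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
    additional_variables=["TMBR_00001", "TMBR_00002"],
    sensor_type="AMSU",
)

# collate_fn stacks the per-sample dicts into a single dict of batched tensors.
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["timestamp"].shape)  # torch.Size([4])
print(batch["metadata"].shape)   # (4, number of non-descriptor columns)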
Tests for the nnjai_wrapp module (new file)
@@ -0,0 +1,166 @@
""" | ||||||
Tests for the nnjai_wrapp module in the graph_weather package. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
This file contains unit tests for AMSUDataset and collate_fn functions. | ||||||
""" | ||||||
|
||||||

from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
import torch

from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn

@pytest.fixture
def mock_datacatalog():
    """
    Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
    This mock provides a mock dataset with predefined columns and values.
    """
    with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
        # Mock dataset structure
        mock_df = MagicMock()
        mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]

        # Define a mock row
        class MockRow:
            def __getitem__(self, key):
                data = {
                    "OBS_TIMESTAMP": datetime.now(),
                    "LAT": 45.0,
                    "LON": -120.0,
                    "TMBR_00001": 250.0,
                    "TMBR_00002": 260.0,
                }
                return data.get(key, None)

        # Configure mock dataset
        mock_row = MockRow()
        mock_df.iloc = MagicMock()
        mock_df.iloc.__getitem__.return_value = mock_row
        mock_df.__len__.return_value = 100

        mock_dataset = MagicMock()
        mock_dataset.load_dataset.return_value = mock_df
        mock_dataset.sel.return_value = mock_dataset
        mock_dataset.load_manifest = MagicMock()

        mock.return_value.__getitem__.return_value = mock_dataset
        yield mock

def test_sensor_dataset(mock_datacatalog):
    """
    Test the SensorDataset class to ensure proper data loading and tensor
    structure for different sensors.
    """
    # Test for AMSU dataset
    dataset_name = "amsu-1bamua-NC021023"
    time = datetime(2021, 1, 1, 0, 0)  # Using a datetime object instead of a string
    primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
    additional_variables = ["TMBR_00001", "TMBR_00002"]
    dataset = SensorDataset(
        dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"
    )

Review comment: Do these tests still pass? I think they need to be updated for the new names. Ideally, there should be tests for all the different sensor types as well in…
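
As an illustration of that suggestion (hypothetical code, not part of this diff), the existing fixture and imports of this test file could be reused with pytest.mark.parametrize to cover every sensor type handled by SensorDataset:

@pytest.mark.parametrize(
    "dataset_name, sensor_type, additional_variables",
    [
        ("amsu-1bamua-NC021023", "AMSU", ["TMBR_00001", "TMBR_00002"]),
        ("atms-atms-NC021203", "ATMS", ["TMBR_00001", "TMBR_00002"]),
        ("mhs-1bmhs-NC021027", "MHS", ["TMBR_00001", "TMBR_00002"]),
        ("iasi-mtiasi-NC021241", "IASI", ["IASIL1CB"]),
        ("cris-crisf4-NC021206", "CrIS", ["SRAD01_00001", "SRAD01_00002"]),
    ],
)
def test_each_sensor_type(mock_datacatalog, dataset_name, sensor_type, additional_variables):
    """Hypothetical sketch: one parametrized test per supported sensor type."""
    dataset = SensorDataset(
        dataset_name,
        datetime(2021, 1, 1, 0, 0),
        ["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables,
        sensor_type=sensor_type,
    )
    item = dataset[0]
    assert set(item.keys()) == {"timestamp", "latitude", "longitude", "metadata"}

The diff of test_sensor_dataset resumes below.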
    # Test dataset length
    assert len(dataset) > 0, "Dataset should not be empty."

    item = dataset[0]
    expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
    assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."

    # Validate tensor properties
    assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
    assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
    assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."
    assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
    assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
    assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."
    assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
    assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
    assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."
    assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
    assert item["metadata"].shape == (
        len(additional_variables),
    ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
    assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."


def test_collate_function():
    """
    Test the collate_fn function to ensure proper batching of dataset items.
    """
    # Mock a batch of items
    batch_size = 4
    metadata_size = 2
    mock_batch = [
        {
            "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
            "latitude": torch.tensor(45.0, dtype=torch.float32),
            "longitude": torch.tensor(-120.0, dtype=torch.float32),
            "metadata": torch.randn(metadata_size, dtype=torch.float32),
        }
        for _ in range(batch_size)
    ]

    # Collate the batch
    batched = collate_fn(mock_batch)

    # Validate batched shapes and types
    assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
    assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
    assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
    assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch."
    assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
    assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
    assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
    assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."


def test_sensor_datasets(mock_datacatalog):
    """
    Test various sensor datasets (AMSU-A, ATMS, MHS, IASI, CrIS) to ensure they
    load properly and print the relevant information.
    """
    # Define datasets and associated parameters for different sensors
    sensors = [
        {
            "name": "amsu-1bamua-NC021023",
            "time": datetime(2021, 1, 1, 0, 0),
            "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
            "additional_variables": ["TMBR_00001", "TMBR_00002"],
            "sensor_type": "AMSU",
        },
        {
            "name": "atms-atms-NC021203",
            "time": datetime(2021, 1, 1, 0, 0),
            "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
            "additional_variables": ["TMBR_00001", "TMBR_00002"],
            "sensor_type": "ATMS",
        },
        {
            "name": "mhs-1bmhs-NC021027",
            "time": datetime(2021, 1, 1, 0, 0),
            "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
            "additional_variables": ["TMBR_00001", "TMBR_00002"],
            "sensor_type": "MHS",
        },
        {
            "name": "iasi-mtiasi-NC021241",
            "time": datetime(2021, 1, 1, 0, 0),
            "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
            "additional_variables": ["IASIL1CB"],
            "sensor_type": "IASI",
        },
        {
            "name": "cris-crisf4-NC021206",
            "time": datetime(2021, 1, 1, 0, 0),
            "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"],
            "additional_variables": ["SRAD01_00001", "SRAD01_00002"],
            "sensor_type": "CrIS",
        },
    ]

    # Loop through each sensor and load the dataset
    for sensor in sensors:
        print(f"\nTesting sensor: {sensor['name']}")

        # Create the dataset instance
        dataset = SensorDataset(
            sensor["name"],
            sensor["time"],
            sensor["primary_descriptors"],
            sensor["additional_variables"],
            sensor_type=sensor["sensor_type"],
        )

        # Print dataset length
        print(f"Dataset length for {sensor['name']}: {len(dataset)}")

        # Retrieve and print the first item
        item = dataset[0]
        print(f"First item from {sensor['name']}:")
        print(item)

        # Ensure the dataset item structure is correct
        expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
        assert set(item.keys()) == expected_keys, f"Dataset item keys for {sensor['name']} are not as expected."

        # Validate tensor properties
        assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor['name']}."
        assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor['name']}."
        assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor['name']}."

        assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor['name']}."
        assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor['name']}."
        assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor['name']}."

        assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor['name']}."
        assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor['name']}."
        assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor['name']}."

        assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor['name']}."
        assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor['name']}."

        print(f"Metadata for {sensor['name']}: {item['metadata']}\n")
Review comment (on the new import in graph_weather/__init__.py): I don't think we should actually import this here, since we have `nnja` as an optional dependency. I think to import this dataset, like in the tests, we should just do `from graph_weather.data.nnja_ai import SensorDataset`, for example. I also renamed this on the main branch, but could you update `nnjai_wrapp.py` to be named `nnja_ai`, as I think that is maybe more consistent naming.
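
A sketch of one way to honor that comment (hypothetical code under the suggested `nnja_ai.py` name, not part of this diff): keep `nnja` out of the package-level imports entirely, and raise only when the dataset is actually constructed, so `import graph_weather` works without the optional dependency.

# Hypothetical import guard for graph_weather/data/nnja_ai.py -- a sketch,
# not the actual change. The package stays importable without nnja installed;
# the failure surfaces only when SensorDataset is instantiated.
from torch.utils.data import Dataset

try:
    from nnja import DataCatalog
except ImportError:  # nnja is an optional dependency
    DataCatalog = None


class SensorDataset(Dataset):
    def __init__(self, dataset_name, time, primary_descriptors,
                 additional_variables, sensor_type="AMSU"):
        if DataCatalog is None:
            raise ImportError(
                "nnja-ai is required for SensorDataset; install it with "
                "`pip install git+https://github.com/brightbandtech/nnja-ai.git`"
            )
        ...  # continue with catalog loading as in the diff above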