From 2a8c6e834359d4e0b5179334c96a6076adb996e2 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Mon, 20 Jan 2025 21:11:10 +0530 Subject: [PATCH 01/14] nnjai support --- graph_weather/data/nnjai_wrapp.py | 74 +++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 graph_weather/data/nnjai_wrapp.py diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py new file mode 100644 index 00000000..64b36e0e --- /dev/null +++ b/graph_weather/data/nnjai_wrapp.py @@ -0,0 +1,74 @@ +import numpy as np +import pandas as pd +from nnja.io import _check_authentication +from torch.utils.data import Dataset, DataLoader + +if _check_authentication(): + from nnja import DataCatalog + + class AMSUDataset(Dataset): + def __init__(self, dataset_name, time, primary_descriptors, additional_variables): + """ + Initialize the AMSU dataset loader. + :param dataset_name: Name of the dataset to load. + :param time: Specific timestamp to filter the data. + :param primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). + :param additional_variables: List of additional variables to include in metadata. + """ + self.dataset_name = dataset_name + self.time = time + self.primary_descriptors = primary_descriptors + self.additional_variables = additional_variables + + # Load data catalog and dataset + self.catalog = DataCatalog(skip_manifest=True) + self.dataset = self.catalog[self.dataset_name] + self.dataset.load_manifest() + + self.dataset = self.dataset.sel(time=self.time, variables=self.primary_descriptors + self.additional_variables) + self.dataframe = self.dataset.load_dataset(engine='pandas') + + for col in primary_descriptors: + if col not in self.dataframe.columns: + raise ValueError(f"The dataset must include a '{col}' column.") + + self.metadata_columns = [ + col for col in self.dataframe.columns if col not in self.primary_descriptors + ] + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, index): + """ + Returns the observation and metadata for a given index. + :param index: Index of the observation to retrieve. + :return: A tuple (time, latitude, longitude, metadata). + """ + row = self.dataframe.iloc[index] + time = row["OBS_TIMESTAMP"].timestamp() + latitude = row["LAT"] + longitude = row["LON"] + metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) + return time, latitude, longitude, metadata + + # Configuration + dataset_name = "amsua-1bamua-NC021023" + time = "2021-01-01 00Z" + primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] + additional_variables = ["TMBR_00001"] + + # Initialize dataset + amsu_dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) + + # Use DataLoader without batching + data_loader = DataLoader(amsu_dataset, shuffle=True) + + # Example usage + for time, latitude, longitude, metadata in data_loader: + print("Time:", time) + print("Latitude:", latitude) + print("Longitude:", longitude) + print("Metadata:", metadata) +else: + print("Install nnjai lib. 
pip install git+https://github.com/brightbandtech/nnja-ai.git") From c70d3c16d970ffc0c491397bb8893b9aea05ffd9 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Mon, 20 Jan 2025 21:51:54 +0530 Subject: [PATCH 02/14] ruff format --- graph_weather/data/nnjai_wrapp.py | 36 +++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index 64b36e0e..b40230a1 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -1,18 +1,33 @@ +""" +This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets. + +The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and +variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. +""" + import numpy as np -import pandas as pd from nnja.io import _check_authentication -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset if _check_authentication(): from nnja import DataCatalog class AMSUDataset(Dataset): + """ + A custom PyTorch Dataset for handling AMSU data. + + This dataset retrieves observations and their metadata, filtered by the provided time and + variable descriptors. + """ + def __init__(self, dataset_name, time, primary_descriptors, additional_variables): """ Initialize the AMSU dataset loader. + :param dataset_name: Name of the dataset to load. :param time: Specific timestamp to filter the data. - :param primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). + :param primary_descriptors: List of primary descriptor variables to include (e.g., + OBS_TIMESTAMP, LAT, LON). :param additional_variables: List of additional variables to include in metadata. """ self.dataset_name = dataset_name @@ -25,8 +40,13 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables self.dataset = self.catalog[self.dataset_name] self.dataset.load_manifest() - self.dataset = self.dataset.sel(time=self.time, variables=self.primary_descriptors + self.additional_variables) - self.dataframe = self.dataset.load_dataset(engine='pandas') + self.dataset = self.dataset.sel( + time=self.time, + variables=self.primary_descriptors + self.additional_variables, + ) + self.dataframe = self.dataset.load_dataset( + engine="pandas" + ) for col in primary_descriptors: if col not in self.dataframe.columns: @@ -37,16 +57,20 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables ] def __len__(self): + """ + Returns the total number of samples in the dataset. + """ return len(self.dataframe) def __getitem__(self, index): """ Returns the observation and metadata for a given index. + :param index: Index of the observation to retrieve. :return: A tuple (time, latitude, longitude, metadata). 
""" row = self.dataframe.iloc[index] - time = row["OBS_TIMESTAMP"].timestamp() + time = row["OBS_TIMESTAMP"].timestamp() latitude = row["LAT"] longitude = row["LON"] metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) From 9041657ead402eaf8b7e476058cb527a43a0cb25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:21:35 +0000 Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/data/nnjai_wrapp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index b40230a1..4b47132f 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -1,7 +1,7 @@ """ This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets. -The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and +The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. """ @@ -16,7 +16,7 @@ class AMSUDataset(Dataset): """ A custom PyTorch Dataset for handling AMSU data. - This dataset retrieves observations and their metadata, filtered by the provided time and + This dataset retrieves observations and their metadata, filtered by the provided time and variable descriptors. """ @@ -26,7 +26,7 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables :param dataset_name: Name of the dataset to load. :param time: Specific timestamp to filter the data. - :param primary_descriptors: List of primary descriptor variables to include (e.g., + :param primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). :param additional_variables: List of additional variables to include in metadata. """ @@ -44,9 +44,7 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables time=self.time, variables=self.primary_descriptors + self.additional_variables, ) - self.dataframe = self.dataset.load_dataset( - engine="pandas" - ) + self.dataframe = self.dataset.load_dataset(engine="pandas") for col in primary_descriptors: if col not in self.dataframe.columns: From 999eafb67fe4a0a12b027b7578151cd20ab62529 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Mon, 20 Jan 2025 22:50:21 +0530 Subject: [PATCH 04/14] changes as per requested --- graph_weather/data/nnjai_wrapp.py | 189 ++++++++++++++++++------------ 1 file changed, 112 insertions(+), 77 deletions(-) diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index 4b47132f..eb7cc287 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -1,96 +1,131 @@ """ -This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets. +A custom PyTorch Dataset implementation for AMSU datasets. +This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets. The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. 
""" import numpy as np -from nnja.io import _check_authentication +import torch from torch.utils.data import DataLoader, Dataset -if _check_authentication(): +try: from nnja import DataCatalog - - class AMSUDataset(Dataset): +except ImportError: + print("NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`") + +class AMSUDataset(Dataset): + """A custom PyTorch Dataset for handling AMSU data. + + This dataset retrieves observations and their metadata, filtered by the provided time and + variable descriptors. + """ + + def __init__(self, dataset_name, time, primary_descriptors, additional_variables): + """Initialize the AMSU dataset loader. + + Args: + dataset_name: Name of the dataset to load. + time: Specific timestamp to filter the data. + primary_descriptors: List of primary descriptor variables to include (e.g., + OBS_TIMESTAMP, LAT, LON). + additional_variables: List of additional variables to include in metadata. """ - A custom PyTorch Dataset for handling AMSU data. - - This dataset retrieves observations and their metadata, filtered by the provided time and - variable descriptors. + self.dataset_name = dataset_name + self.time = time + self.primary_descriptors = primary_descriptors + self.additional_variables = additional_variables + + # Load data catalog and dataset + self.catalog = DataCatalog(skip_manifest=True) + self.dataset = self.catalog[self.dataset_name] + self.dataset.load_manifest() + + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + self.additional_variables + ) + self.dataframe = self.dataset.load_dataset(engine="pandas") + + for col in primary_descriptors: + if col not in self.dataframe.columns: + raise ValueError(f"The dataset must include a '{col}' column.") + + self.metadata_columns = [ + col for col in self.dataframe.columns if col not in self.primary_descriptors + ] + + def __len__(self): + """Return the total number of samples in the dataset.""" + return len(self.dataframe) + + def __getitem__(self, index): + """Return the observation and metadata for a given index. + + Args: + index: Index of the observation to retrieve. + + Returns: + A dictionary containing timestamp, latitude, longitude, and metadata. """ - - def __init__(self, dataset_name, time, primary_descriptors, additional_variables): - """ - Initialize the AMSU dataset loader. - - :param dataset_name: Name of the dataset to load. - :param time: Specific timestamp to filter the data. - :param primary_descriptors: List of primary descriptor variables to include (e.g., - OBS_TIMESTAMP, LAT, LON). - :param additional_variables: List of additional variables to include in metadata. 
- """ - self.dataset_name = dataset_name - self.time = time - self.primary_descriptors = primary_descriptors - self.additional_variables = additional_variables - - # Load data catalog and dataset - self.catalog = DataCatalog(skip_manifest=True) - self.dataset = self.catalog[self.dataset_name] - self.dataset.load_manifest() - - self.dataset = self.dataset.sel( - time=self.time, - variables=self.primary_descriptors + self.additional_variables, - ) - self.dataframe = self.dataset.load_dataset(engine="pandas") - - for col in primary_descriptors: - if col not in self.dataframe.columns: - raise ValueError(f"The dataset must include a '{col}' column.") - - self.metadata_columns = [ - col for col in self.dataframe.columns if col not in self.primary_descriptors - ] - - def __len__(self): - """ - Returns the total number of samples in the dataset. - """ - return len(self.dataframe) - - def __getitem__(self, index): - """ - Returns the observation and metadata for a given index. - - :param index: Index of the observation to retrieve. - :return: A tuple (time, latitude, longitude, metadata). - """ - row = self.dataframe.iloc[index] - time = row["OBS_TIMESTAMP"].timestamp() - latitude = row["LAT"] - longitude = row["LON"] - metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) - return time, latitude, longitude, metadata - + row = self.dataframe.iloc[index] + time = row["OBS_TIMESTAMP"].timestamp() + latitude = row["LAT"] + longitude = row["LON"] + metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) + + return { + "timestamp": torch.tensor(time, dtype=torch.float32), + "latitude": torch.tensor(latitude, dtype=torch.float32), + "longitude": torch.tensor(longitude, dtype=torch.float32), + "metadata": torch.from_numpy(metadata) + } + +def collate_fn(batch): + """Custom collate function to handle batching of dictionary data. + + Args: + batch: List of dictionaries from __getitem__ + + Returns: + Single dictionary with batched tensors + """ + return { + key: torch.stack([item[key] for item in batch]) + for key in batch[0].keys() + } + +if __name__ == "__main__": # Configuration dataset_name = "amsua-1bamua-NC021023" time = "2021-01-01 00Z" primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] additional_variables = ["TMBR_00001"] - + # Initialize dataset amsu_dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) - - # Use DataLoader without batching - data_loader = DataLoader(amsu_dataset, shuffle=True) - - # Example usage - for time, latitude, longitude, metadata in data_loader: - print("Time:", time) - print("Latitude:", latitude) - print("Longitude:", longitude) - print("Metadata:", metadata) -else: - print("Install nnjai lib. 
pip install git+https://github.com/brightbandtech/nnja-ai.git") + + batch_size = 4 + data_loader = DataLoader( + amsu_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn + ) + + # Example usage with batched data + for batch in data_loader: + print(f"Batch size: {batch['timestamp'].shape[0]}") + print("Timestamps shape:", batch["timestamp"].shape) + print("Latitudes shape:", batch["latitude"].shape) + print("Longitudes shape:", batch["longitude"].shape) + print("Metadata shape:", batch["metadata"].shape) + + for i in range(batch_size): + print(f"\nItem {i}:") + print("Time:", batch["timestamp"][i].item()) + print("Latitude:", batch["latitude"][i].item()) + print("Longitude:", batch["longitude"][i].item()) + print("Metadata:", batch["metadata"][i]) + + break \ No newline at end of file From 7440174e864dd1e47f12ae1eb68fc8d3b61c0d42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:53:17 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/data/nnjai_wrapp.py | 59 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index eb7cc287..c0e022d0 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -13,18 +13,21 @@ try: from nnja import DataCatalog except ImportError: - print("NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`") + print( + "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`" + ) + class AMSUDataset(Dataset): """A custom PyTorch Dataset for handling AMSU data. - + This dataset retrieves observations and their metadata, filtered by the provided time and variable descriptors. """ - + def __init__(self, dataset_name, time, primary_descriptors, additional_variables): """Initialize the AMSU dataset loader. - + Args: dataset_name: Name of the dataset to load. time: Specific timestamp to filter the data. @@ -36,35 +39,35 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables self.time = time self.primary_descriptors = primary_descriptors self.additional_variables = additional_variables - + # Load data catalog and dataset self.catalog = DataCatalog(skip_manifest=True) self.dataset = self.catalog[self.dataset_name] self.dataset.load_manifest() - + self.dataset = self.dataset.sel( time=self.time, variables=self.primary_descriptors + self.additional_variables ) self.dataframe = self.dataset.load_dataset(engine="pandas") - + for col in primary_descriptors: if col not in self.dataframe.columns: raise ValueError(f"The dataset must include a '{col}' column.") - + self.metadata_columns = [ col for col in self.dataframe.columns if col not in self.primary_descriptors ] - + def __len__(self): """Return the total number of samples in the dataset.""" return len(self.dataframe) - + def __getitem__(self, index): """Return the observation and metadata for a given index. - + Args: index: Index of the observation to retrieve. - + Returns: A dictionary containing timestamp, latitude, longitude, and metadata. 
""" @@ -73,27 +76,26 @@ def __getitem__(self, index): latitude = row["LAT"] longitude = row["LON"] metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) - + return { "timestamp": torch.tensor(time, dtype=torch.float32), "latitude": torch.tensor(latitude, dtype=torch.float32), "longitude": torch.tensor(longitude, dtype=torch.float32), - "metadata": torch.from_numpy(metadata) + "metadata": torch.from_numpy(metadata), } + def collate_fn(batch): """Custom collate function to handle batching of dictionary data. - + Args: batch: List of dictionaries from __getitem__ - + Returns: Single dictionary with batched tensors """ - return { - key: torch.stack([item[key] for item in batch]) - for key in batch[0].keys() - } + return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} + if __name__ == "__main__": # Configuration @@ -101,18 +103,15 @@ def collate_fn(batch): time = "2021-01-01 00Z" primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] additional_variables = ["TMBR_00001"] - + # Initialize dataset amsu_dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) - + batch_size = 4 data_loader = DataLoader( - amsu_dataset, - batch_size=batch_size, - shuffle=True, - collate_fn=collate_fn + amsu_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn ) - + # Example usage with batched data for batch in data_loader: print(f"Batch size: {batch['timestamp'].shape[0]}") @@ -120,12 +119,12 @@ def collate_fn(batch): print("Latitudes shape:", batch["latitude"].shape) print("Longitudes shape:", batch["longitude"].shape) print("Metadata shape:", batch["metadata"].shape) - + for i in range(batch_size): print(f"\nItem {i}:") print("Time:", batch["timestamp"][i].item()) print("Latitude:", batch["latitude"][i].item()) print("Longitude:", batch["longitude"][i].item()) print("Metadata:", batch["metadata"][i]) - - break \ No newline at end of file + + break From 77598182d37e210a2417ef7d89a5e720e872adc8 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Tue, 21 Jan 2025 22:49:14 +0530 Subject: [PATCH 06/14] test_nnjai.py --- graph_weather/__init__.py | 1 + graph_weather/data/__init__.py | 2 + graph_weather/data/nnjai_wrapp.py | 38 +-------- tests/test_nnjai.py | 133 ++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 34 deletions(-) create mode 100644 tests/test_nnjai.py diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 1758fd49..bdf798b0 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -2,3 +2,4 @@ from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster +from .data.nnjai_wrapp import (AMSUDataset,collate_fn) \ No newline at end of file diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index 6eb48e01..88796a44 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1 +1,3 @@ """Dataloaders and data processing utilities""" + +from .nnjai_wrapp import (AMSUDataset,collate_fn) \ No newline at end of file diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index c0e022d0..f87539d0 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -94,37 +94,7 @@ def collate_fn(batch): Returns: Single dictionary with batched tensors """ - return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} - - -if __name__ == "__main__": - # Configuration - dataset_name = "amsua-1bamua-NC021023" 
- time = "2021-01-01 00Z" - primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] - additional_variables = ["TMBR_00001"] - - # Initialize dataset - amsu_dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) - - batch_size = 4 - data_loader = DataLoader( - amsu_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn - ) - - # Example usage with batched data - for batch in data_loader: - print(f"Batch size: {batch['timestamp'].shape[0]}") - print("Timestamps shape:", batch["timestamp"].shape) - print("Latitudes shape:", batch["latitude"].shape) - print("Longitudes shape:", batch["longitude"].shape) - print("Metadata shape:", batch["metadata"].shape) - - for i in range(batch_size): - print(f"\nItem {i}:") - print("Time:", batch["timestamp"][i].item()) - print("Latitude:", batch["latitude"][i].item()) - print("Longitude:", batch["longitude"][i].item()) - print("Metadata:", batch["metadata"][i]) - - break + return { + key: torch.stack([item[key] for item in batch]) + for key in batch[0].keys() + } diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py new file mode 100644 index 00000000..8741efe8 --- /dev/null +++ b/tests/test_nnjai.py @@ -0,0 +1,133 @@ +""" +Tests for the nnjai_wrapp module in the graph_weather package. + +This file contains unit tests for AMSUDataset and collate_fn functions. +""" + + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from graph_weather.data.nnjai_wrapp import AMSUDataset,collate_fn + +# Mock the DataCatalog to avoid actual data loading +@pytest.fixture +def mock_datacatalog(): + """ + Fixture to mock the DataCatalog for unit tests to avoid actual data loading. + + This mock provides a mock dataset with predefined columns and values. + """ + with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock: + # Mock dataset structure + mock_df = MagicMock() + mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"] + + # Define a mock row + class MockRow: + def __getitem__(self, key): + data = { + "OBS_TIMESTAMP": datetime.now(), + "LAT": 45.0, + "LON": -120.0, + "TMBR_00001": 250.0, + "TMBR_00002": 260.0, + } + return data.get(key, None) + + # Configure mock dataset + mock_row = MockRow() + mock_df.iloc = MagicMock() + mock_df.iloc.__getitem__.return_value = mock_row + mock_df.__len__.return_value = 100 + + mock_dataset = MagicMock() + mock_dataset.load_dataset.return_value = mock_df + mock_dataset.sel.return_value = mock_dataset + mock_dataset.load_manifest = MagicMock() + + mock.return_value.__getitem__.return_value = mock_dataset + yield mock + + +def test_amsu_dataset(mock_datacatalog): + """ + Test the AMSUDataset class to ensure proper data loading and tensor structure. + + This test validates the AMSUDataset class for its ability to load the dataset + correctly, check for the appropriate tensor properties, and ensure the keys + and data types match expectations. + """ + # Initialize dataset parameters + dataset_name = "amsua-1bamua-NC021023" + time = "2021-01-01 00Z" + primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] + additional_variables = ["TMBR_00001", "TMBR_00002"] + + dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) + + # Test dataset length + assert len(dataset) > 0, "Dataset should not be empty." + + item = dataset[0] + expected_keys = {"timestamp", "latitude", "longitude", "metadata"} + assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected." 
+ + # Validate tensor properties + assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor." + assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32." + assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor." + + assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor." + assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32." + assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor." + + assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor." + assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32." + assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor." + + assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor." + assert item["metadata"].shape == (len(additional_variables),), ( + f"Metadata shape mismatch. Expected ({len(additional_variables)},)." + ) + assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32." + + +def test_collate_function(): + """ + Test the collate_fn function to ensure proper batching of dataset items. + + This test checks that the collate_fn properly batches the timestamp, latitude, + longitude, and metadata fields of the dataset, ensuring correct shapes and data types. + """ + # Mock a batch of items + batch_size = 4 + metadata_size = 2 + mock_batch = [ + { + "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32), + "latitude": torch.tensor(45.0, dtype=torch.float32), + "longitude": torch.tensor(-120.0, dtype=torch.float32), + "metadata": torch.randn(metadata_size, dtype=torch.float32), + } + for _ in range(batch_size) + ] + + # Collate the batch + batched = collate_fn(mock_batch) + + # Validate batched shapes and types + assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch." + assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch." + assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch." + assert batched["metadata"].shape == (batch_size, metadata_size), ( + "Metadata batch shape mismatch." + ) + + assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch." + assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch." + assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch." + assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch." 
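
For context: a minimal usage sketch of the wrapper as of PATCH 06/14, which
removes the inline example from nnjai_wrapp.py and instead exposes
AMSUDataset and collate_fn through the package __init__ files. This sketch is
not part of any patch in the series; it assumes an installed and
authenticated nnja-ai environment, and it reuses the dataset name, time, and
variable ids from the example block the patch removes.

    from torch.utils.data import DataLoader

    from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn

    dataset = AMSUDataset(
        dataset_name="amsua-1bamua-NC021023",
        time="2021-01-01 00Z",
        primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables=["TMBR_00001"],
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

    batch = next(iter(loader))
    # collate_fn stacks each dict field, so timestamp/latitude/longitude come
    # out with shape (batch_size,) and metadata with shape
    # (batch_size, len(additional_variables)).
    print(batch["timestamp"].shape, batch["metadata"].shape)

The custom collate_fn is needed because the dataset returns dictionaries
rather than tuples, which PyTorch's default collation does not stack per key
in this layout.
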
From c400f0c2a5c1eabc90c53686a5b083075b4c47ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 17:28:22 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/__init__.py | 2 +- graph_weather/data/__init__.py | 2 +- graph_weather/data/nnjai_wrapp.py | 7 ++----- tests/test_nnjai.py | 17 +++++++++-------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index bdf798b0..74719deb 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,5 +1,5 @@ """Main import for the complete models""" +from .data.nnjai_wrapp import AMSUDataset, collate_fn from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster -from .data.nnjai_wrapp import (AMSUDataset,collate_fn) \ No newline at end of file diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index 88796a44..be2e2805 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,3 +1,3 @@ """Dataloaders and data processing utilities""" -from .nnjai_wrapp import (AMSUDataset,collate_fn) \ No newline at end of file +from .nnjai_wrapp import AMSUDataset, collate_fn diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index f87539d0..61c209d2 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -8,7 +8,7 @@ import numpy as np import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset try: from nnja import DataCatalog @@ -94,7 +94,4 @@ def collate_fn(batch): Returns: Single dictionary with batched tensors """ - return { - key: torch.stack([item[key] for item in batch]) - for key in batch[0].keys() - } + return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index 8741efe8..2cbf7508 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -4,14 +4,14 @@ This file contains unit tests for AMSUDataset and collate_fn functions. """ - from datetime import datetime from unittest.mock import MagicMock, patch import pytest import torch -from graph_weather.data.nnjai_wrapp import AMSUDataset,collate_fn +from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn + # Mock the DataCatalog to avoid actual data loading @pytest.fixture @@ -90,9 +90,9 @@ def test_amsu_dataset(mock_datacatalog): assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor." assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor." - assert item["metadata"].shape == (len(additional_variables),), ( - f"Metadata shape mismatch. Expected ({len(additional_variables)},)." - ) + assert item["metadata"].shape == ( + len(additional_variables), + ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)." assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32." @@ -123,9 +123,10 @@ def test_collate_function(): assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch." assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch." assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch." - assert batched["metadata"].shape == (batch_size, metadata_size), ( - "Metadata batch shape mismatch." 
- ) + assert batched["metadata"].shape == ( + batch_size, + metadata_size, + ), "Metadata batch shape mismatch." assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch." assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch." From 3476d28994d7955781a17d98447a4d3934d5c574 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 31 Jan 2025 17:43:32 +0530 Subject: [PATCH 08/14] NNJAI with different tensors --- graph_weather/__init__.py | 2 +- graph_weather/data/__init__.py | 2 +- graph_weather/data/nnjai_wrapp.py | 55 +++++++++++-------- tests/test_nnjai.py | 90 +++++++++++++++++++++---------- 4 files changed, 95 insertions(+), 54 deletions(-) diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 74719deb..26625e60 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,5 +1,5 @@ """Main import for the complete models""" -from .data.nnjai_wrapp import AMSUDataset, collate_fn +from .data.nnjai_wrapp import SensorDataset, collate_fn from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index be2e2805..ca21c387 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,3 +1,3 @@ """Dataloaders and data processing utilities""" -from .nnjai_wrapp import AMSUDataset, collate_fn +from .nnjai_wrapp import SensorDataset, collate_fn diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py index 61c209d2..d66c9465 100644 --- a/graph_weather/data/nnjai_wrapp.py +++ b/graph_weather/data/nnjai_wrapp.py @@ -1,7 +1,6 @@ """ -A custom PyTorch Dataset implementation for AMSU datasets. +A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS -This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets. The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. """ @@ -18,36 +17,53 @@ ) -class AMSUDataset(Dataset): - """A custom PyTorch Dataset for handling AMSU data. +class SensorDataset(Dataset): + """A custom PyTorch Dataset for handling various sensor data.""" - This dataset retrieves observations and their metadata, filtered by the provided time and - variable descriptors. - """ - - def __init__(self, dataset_name, time, primary_descriptors, additional_variables): - """Initialize the AMSU dataset loader. + def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): + """Initialize the dataset loader for various sensors. Args: dataset_name: Name of the dataset to load. time: Specific timestamp to filter the data. - primary_descriptors: List of primary descriptor variables to include (e.g., - OBS_TIMESTAMP, LAT, LON). + primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). additional_variables: List of additional variables to include in metadata. 
+ sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS) """ self.dataset_name = dataset_name self.time = time self.primary_descriptors = primary_descriptors self.additional_variables = additional_variables + self.sensor_type = sensor_type # New argument for selecting sensor type # Load data catalog and dataset self.catalog = DataCatalog(skip_manifest=True) self.dataset = self.catalog[self.dataset_name] self.dataset.load_manifest() - self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + self.additional_variables - ) + if self.sensor_type == "AMSU": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)] + ) + elif self.sensor_type == "ATMS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)] + ) + elif self.sensor_type == "MHS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)] + ) + elif self.sensor_type == "IASI": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)] + ) + elif self.sensor_type == "CrIS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] + ) + else: + raise ValueError(f"Unsupported sensor type: {self.sensor_type}") + self.dataframe = self.dataset.load_dataset(engine="pandas") for col in primary_descriptors: @@ -63,14 +79,7 @@ def __len__(self): return len(self.dataframe) def __getitem__(self, index): - """Return the observation and metadata for a given index. - - Args: - index: Index of the observation to retrieve. - - Returns: - A dictionary containing timestamp, latitude, longitude, and metadata. - """ + """Return the observation and metadata for a given index.""" row = self.dataframe.iloc[index] time = row["OBS_TIMESTAMP"].timestamp() latitude = row["LAT"] diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index 2cbf7508..da706782 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -10,15 +10,12 @@ import pytest import torch -from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn +from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn - -# Mock the DataCatalog to avoid actual data loading @pytest.fixture def mock_datacatalog(): """ Fixture to mock the DataCatalog for unit tests to avoid actual data loading. - This mock provides a mock dataset with predefined columns and values. """ with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock: @@ -53,21 +50,16 @@ def __getitem__(self, key): yield mock -def test_amsu_dataset(mock_datacatalog): +def test_sensor_dataset(mock_datacatalog): """ - Test the AMSUDataset class to ensure proper data loading and tensor structure. - - This test validates the AMSUDataset class for its ability to load the dataset - correctly, check for the appropriate tensor properties, and ensure the keys - and data types match expectations. + Test the SensorDataset class to ensure proper data loading and tensor structure for different sensors. 
""" - # Initialize dataset parameters - dataset_name = "amsua-1bamua-NC021023" - time = "2021-01-01 00Z" + # Test for AMSU dataset + dataset_name = "amsu-1bamua-NC021023" + time = datetime(2021, 1, 1, 0, 0) # Using datetime object instead of string primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] additional_variables = ["TMBR_00001", "TMBR_00002"] - - dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) + dataset = SensorDataset(dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU") # Test dataset length assert len(dataset) > 0, "Dataset should not be empty." @@ -80,28 +72,20 @@ def test_amsu_dataset(mock_datacatalog): assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor." assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32." assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor." - assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor." assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32." assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor." - assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor." assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32." assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor." - assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor." - assert item["metadata"].shape == ( - len(additional_variables), - ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)." + assert item["metadata"].shape == (len(additional_variables),), f"Metadata shape mismatch. Expected ({len(additional_variables)},)." assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32." def test_collate_function(): """ Test the collate_fn function to ensure proper batching of dataset items. - - This test checks that the collate_fn properly batches the timestamp, latitude, - longitude, and metadata fields of the dataset, ensuring correct shapes and data types. """ # Mock a batch of items batch_size = 4 @@ -123,12 +107,60 @@ def test_collate_function(): assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch." assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch." assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch." - assert batched["metadata"].shape == ( - batch_size, - metadata_size, - ), "Metadata batch shape mismatch." - + assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch." assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch." assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch." assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch." assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch." + + +def test_sensor_datasets(mock_datacatalog): + """ + Test various sensor datasets (AMSU-A, ATMS, MHS, IASI, CrIS) to ensure they load properly + and print the relevant information. 
+ """ + # Define datasets and associated parameters for different sensors + sensors = [ + {"name": "amsu-1bamua-NC021023", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "AMSU"}, + {"name": "atms-atms-NC021203", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "ATMS"}, + {"name": "mhs-1bmhs-NC021027", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "MHS"}, + {"name": "iasi-mtiasi-NC021241", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["IASIL1CB"], "sensor_type": "IASI"}, + {"name": "cris-crisf4-NC021206", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["SRAD01_00001", "SRAD01_00002"], "sensor_type": "CrIS"} + ] + + # Loop through each sensor and load the dataset + for sensor in sensors: + print(f"\nTesting sensor: {sensor['name']}") + + # Create the dataset instance + dataset = SensorDataset(sensor['name'], sensor['time'], sensor['primary_descriptors'], sensor['additional_variables'], sensor_type=sensor['sensor_type']) + + # Print dataset length + print(f"Dataset length for {sensor['name']}: {len(dataset)}") + + # Retrieve and print the first item + item = dataset[0] + print(f"First item from {sensor['name']}:") + print(item) + + # Ensure the dataset item structure is correct + expected_keys = {"timestamp", "latitude", "longitude", "metadata"} + assert set(item.keys()) == expected_keys, f"Dataset item keys for {sensor['name']} are not as expected." + + # Validate tensor properties + assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor['name']}." + assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor['name']}." + assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor['name']}." + + assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor['name']}." + assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor['name']}." + assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor['name']}." + + assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor['name']}." + assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor['name']}." + assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor['name']}." + + assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor['name']}." 
+ assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor['name']}" + + print(f"Metadata for {sensor['name']}: {item['metadata']}\n") \ No newline at end of file From 0e46fb857ed0b217fe5cf0c2aeecbf574b96df50 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Sun, 9 Feb 2025 12:33:31 +0530 Subject: [PATCH 09/14] Refactor: Updated naming for consistency and clarity --- graph_weather/__init__.py | 3 +- graph_weather/data/nnja_ai.py | 106 ++++++++++++++++++++++++++++++++++ tests/test_nnjai.py | 2 +- 3 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 graph_weather/data/nnja_ai.py diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 26625e60..75a712e7 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,5 +1,6 @@ """Main import for the complete models""" -from .data.nnjai_wrapp import SensorDataset, collate_fn +from graph_weather.data.nnja_ai import NNJADataset, collate_fn from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster +from .models.aurora import LoraLayer, PerceiverProcessor, IntegrationLayer, GenCastConfig, Fengwu_GHRConfig, ValidationError, TransformationError \ No newline at end of file diff --git a/graph_weather/data/nnja_ai.py b/graph_weather/data/nnja_ai.py new file mode 100644 index 00000000..4e7c10d7 --- /dev/null +++ b/graph_weather/data/nnja_ai.py @@ -0,0 +1,106 @@ +""" +A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS + +The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and +variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. +""" + +import numpy as np +import torch +from torch.utils.data import Dataset + +try: + from nnja import DataCatalog +except ImportError: + print( + "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`" + ) + + +class NNJADataset(Dataset): + """A custom PyTorch Dataset for handling various sensor data.""" + + def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): + """Initialize the dataset loader for various sensors. + + Args: + dataset_name: Name of the dataset to load. + time: Specific timestamp to filter the data. + primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). + additional_variables: List of additional variables to include in metadata. 
+ sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS) + """ + self.dataset_name = dataset_name + self.time = time + self.primary_descriptors = primary_descriptors + self.additional_variables = additional_variables + self.sensor_type = sensor_type # New argument for selecting sensor type + + # Load data catalog and dataset + self.catalog = DataCatalog(skip_manifest=True) + self.dataset = self.catalog[self.dataset_name] + self.dataset.load_manifest() + + if self.sensor_type == "AMSU": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)] + ) + elif self.sensor_type == "ATMS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)] + ) + elif self.sensor_type == "MHS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)] + ) + elif self.sensor_type == "IASI": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)] + ) + elif self.sensor_type == "CrIS": + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] + ) + else: + raise ValueError(f"Unsupported sensor type: {self.sensor_type}") + + self.dataframe = self.dataset.load_dataset(engine="pandas") + + for col in primary_descriptors: + if col not in self.dataframe.columns: + raise ValueError(f"The dataset must include a '{col}' column.") + + self.metadata_columns = [ + col for col in self.dataframe.columns if col not in self.primary_descriptors + ] + + def __len__(self): + """Return the total number of samples in the dataset.""" + return len(self.dataframe) + + def __getitem__(self, index): + """Return the observation and metadata for a given index.""" + row = self.dataframe.iloc[index] + time = row["OBS_TIMESTAMP"].timestamp() + latitude = row["LAT"] + longitude = row["LON"] + metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) + + return { + "timestamp": torch.tensor(time, dtype=torch.float32), + "latitude": torch.tensor(latitude, dtype=torch.float32), + "longitude": torch.tensor(longitude, dtype=torch.float32), + "metadata": torch.from_numpy(metadata), + } + + +def collate_fn(batch): + """Custom collate function to handle batching of dictionary data. 
+ + Args: + batch: List of dictionaries from __getitem__ + + Returns: + Single dictionary with batched tensors + """ + return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index da706782..77939437 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -10,7 +10,7 @@ import pytest import torch -from graph_weather.data.nnjai_wrapp import SensorDataset, collate_fn +from graph_weather.data.nnja_ai import NNJADataset, collate_fn @pytest.fixture def mock_datacatalog(): From c75fdf8e4883cde6a6d729ed7cab5b0b9c6ae29c Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Sun, 9 Feb 2025 12:37:32 +0530 Subject: [PATCH 10/14] Refactor: data/__init__.py --- graph_weather/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index ca21c387..356c1864 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,3 +1,3 @@ """Dataloaders and data processing utilities""" -from .nnjai_wrapp import SensorDataset, collate_fn +from .nnja_ai import SensorDataset, collate_fn From 3577ca75b8e8ade0929ad9e41636d8165c900c32 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 21 Feb 2025 22:29:39 +0530 Subject: [PATCH 11/14] pytest fixture implementation --- graph_weather/__init__.py | 5 +- graph_weather/data/nnja_ai.py | 2 +- tests/test_nnjai.py | 255 +++++++++++++++++----------------- 3 files changed, 133 insertions(+), 129 deletions(-) diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 75a712e7..9cf9ddc5 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,6 +1,5 @@ """Main import for the complete models""" -from graph_weather.data.nnja_ai import NNJADataset, collate_fn +from .data.nnja_ai import SensorDataset, collate_fn from .models.analysis import GraphWeatherAssimilator -from .models.forecast import GraphWeatherForecaster -from .models.aurora import LoraLayer, PerceiverProcessor, IntegrationLayer, GenCastConfig, Fengwu_GHRConfig, ValidationError, TransformationError \ No newline at end of file +from .models.forecast import GraphWeatherForecaster \ No newline at end of file diff --git a/graph_weather/data/nnja_ai.py b/graph_weather/data/nnja_ai.py index 4e7c10d7..d66c9465 100644 --- a/graph_weather/data/nnja_ai.py +++ b/graph_weather/data/nnja_ai.py @@ -17,7 +17,7 @@ ) -class NNJADataset(Dataset): +class SensorDataset(Dataset): """A custom PyTorch Dataset for handling various sensor data.""" def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index 77939437..19f8414d 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -1,95 +1,154 @@ """ -Tests for the nnjai_wrapp module in the graph_weather package. +Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior. -This file contains unit tests for AMSUDataset and collate_fn functions. +The tests ensure correct handling of data types, shapes, and batch processing for various sensor types. 
""" -from datetime import datetime +from datetime import datetime from unittest.mock import MagicMock, patch - +import numpy as np import pytest import torch - -from graph_weather.data.nnja_ai import NNJADataset, collate_fn +import pandas as pd + +from graph_weather.data.nnja_ai import SensorDataset, collate_fn + +def get_sensor_variables(sensor_type): + """Helper function to get the correct variables for each sensor type.""" + if sensor_type == "AMSU": + return [f"TMBR_000{i:02d}" for i in range(1, 16)] # 15 channels + elif sensor_type == "ATMS": + return [f"TMBR_000{i:02d}" for i in range(1, 23)] # 22 channels + elif sensor_type == "MHS": + return [f"TMBR_000{i:02d}" for i in range(1, 6)] # 5 channels + elif sensor_type == "IASI": + return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)] # 616 channels + elif sensor_type == "CrIS": + return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] # 431 channels + return [] @pytest.fixture def mock_datacatalog(): """ Fixture to mock the DataCatalog for unit tests to avoid actual data loading. - This mock provides a mock dataset with predefined columns and values. """ - with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock: - # Mock dataset structure - mock_df = MagicMock() - mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"] - - # Define a mock row - class MockRow: - def __getitem__(self, key): - data = { - "OBS_TIMESTAMP": datetime.now(), - "LAT": 45.0, - "LON": -120.0, - "TMBR_00001": 250.0, - "TMBR_00002": 260.0, - } - return data.get(key, None) - - # Configure mock dataset - mock_row = MockRow() - mock_df.iloc = MagicMock() - mock_df.iloc.__getitem__.return_value = mock_row - mock_df.__len__.return_value = 100 - + with patch("graph_weather.data.nnja_ai.DataCatalog") as mock: + # Create a mock catalog + mock_catalog = MagicMock() + + # Create a mock dataset with direct DataFrame return mock_dataset = MagicMock() - mock_dataset.load_dataset.return_value = mock_df - mock_dataset.sel.return_value = mock_dataset mock_dataset.load_manifest = MagicMock() - - mock.return_value.__getitem__.return_value = mock_dataset + mock_dataset.sel = MagicMock(return_value=mock_dataset) # Return self to chain calls + + def create_mock_df(engine="pandas"): + # Get the sensor type from the mock dataset + sensor_vars = get_sensor_variables(mock_dataset.sensor_type) + + # Create DataFrame with required columns + df = pd.DataFrame({ + "OBS_TIMESTAMP": pd.date_range(start=datetime(2021, 1, 1), periods=100, freq='H'), + "LAT": np.full(100, 45.0), + "LON": np.full(100, -120.0) + }) + + # Add sensor-specific variables + for var in sensor_vars: + df[var] = np.full(100, 250.0) + + return df + + # Set up the mock to return our DataFrame + mock_dataset.load_dataset = create_mock_df + + # Configure the catalog to return our mock dataset + def get_mock_dataset(self, name): + # Set the sensor type based on the requested dataset name + mock_dataset.sensor_type = next( + config["sensor_type"] for config in SENSOR_CONFIGS + if config["name"] == name + ) + return mock_dataset + + mock_catalog.__getitem__ = get_mock_dataset # Fix: Explicitly define the method with `self` + mock.return_value = mock_catalog + yield mock - -def test_sensor_dataset(mock_datacatalog): - """ - Test the SensorDataset class to ensure proper data loading and tensor structure for different sensors. 
- """ - # Test for AMSU dataset - dataset_name = "amsu-1bamua-NC021023" - time = datetime(2021, 1, 1, 0, 0) # Using datetime object instead of string +# Test configurations +SENSOR_CONFIGS = [ + { + "name": "amsu-1bamua-NC021023", + "sensor_type": "AMSU", + "expected_metadata_size": 15 # 15 TMBR channels + }, + { + "name": "atms-atms-NC021203", + "sensor_type": "ATMS", + "expected_metadata_size": 22 # 22 TMBR channels + }, + { + "name": "mhs-1bmhs-NC021027", + "sensor_type": "MHS", + "expected_metadata_size": 5 # 5 TMBR channels + }, + { + "name": "iasi-mtiasi-NC021241", + "sensor_type": "IASI", + "expected_metadata_size": 616 # 616 SCRA channels + }, + { + "name": "cris-crisf4-NC021206", + "sensor_type": "CrIS", + "expected_metadata_size": 431 # 431 SRAD channels + } +] + +@pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS) +def test_sensor_dataset(mock_datacatalog, sensor_config): + """Test the SensorDataset class for different sensor types.""" + time = datetime(2021, 1, 1, 0, 0) primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] - additional_variables = ["TMBR_00001", "TMBR_00002"] - dataset = SensorDataset(dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU") + + dataset = SensorDataset( + dataset_name=sensor_config["name"], + time=time, + primary_descriptors=primary_descriptors, + additional_variables=get_sensor_variables(sensor_config["sensor_type"]), + sensor_type=sensor_config["sensor_type"] + ) # Test dataset length - assert len(dataset) > 0, "Dataset should not be empty." + assert len(dataset) > 0, f"Dataset should not be empty for {sensor_config['sensor_type']}" + # Test single item structure item = dataset[0] expected_keys = {"timestamp", "latitude", "longitude", "metadata"} - assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected." + assert set(item.keys()) == expected_keys, f"Dataset item keys are not as expected for {sensor_config['sensor_type']}" # Validate tensor properties - assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor." - assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32." - assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor." - assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor." - assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32." - assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor." - assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor." - assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32." - assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor." - assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor." - assert item["metadata"].shape == (len(additional_variables),), f"Metadata shape mismatch. Expected ({len(additional_variables)},)." - assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32." 
+ assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor_config['sensor_type']}" + assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}" + assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor_config['sensor_type']}" + assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor_config['sensor_type']}" + assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor_config['sensor_type']}" + assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor_config['sensor_type']}" + assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor_config['sensor_type']}" + assert item["metadata"].shape == (sensor_config["expected_metadata_size"],), \ + f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)" + assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor_config['sensor_type']}" def test_collate_function(): - """ - Test the collate_fn function to ensure proper batching of dataset items. - """ - # Mock a batch of items + """Test the collate_fn function to ensure proper batching of dataset items.""" batch_size = 4 - metadata_size = 2 + metadata_size = 15 # Using AMSU size for this test mock_batch = [ { "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32), @@ -100,67 +159,13 @@ def test_collate_function(): for _ in range(batch_size) ] - # Collate the batch batched = collate_fn(mock_batch) - # Validate batched shapes and types - assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch." - assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch." - assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch." - assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch." - assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch." - assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch." - assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch." - assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch." - - -def test_sensor_datasets(mock_datacatalog): - """ - Test various sensor datasets (AMSU-A, ATMS, MHS, IASI, CrIS) to ensure they load properly - and print the relevant information. 
- """ - # Define datasets and associated parameters for different sensors - sensors = [ - {"name": "amsu-1bamua-NC021023", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "AMSU"}, - {"name": "atms-atms-NC021203", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "ATMS"}, - {"name": "mhs-1bmhs-NC021027", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["TMBR_00001", "TMBR_00002"], "sensor_type": "MHS"}, - {"name": "iasi-mtiasi-NC021241", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["IASIL1CB"], "sensor_type": "IASI"}, - {"name": "cris-crisf4-NC021206", "time": datetime(2021, 1, 1, 0, 0), "primary_descriptors": ["OBS_TIMESTAMP", "LAT", "LON"], "additional_variables": ["SRAD01_00001", "SRAD01_00002"], "sensor_type": "CrIS"} - ] - - # Loop through each sensor and load the dataset - for sensor in sensors: - print(f"\nTesting sensor: {sensor['name']}") - - # Create the dataset instance - dataset = SensorDataset(sensor['name'], sensor['time'], sensor['primary_descriptors'], sensor['additional_variables'], sensor_type=sensor['sensor_type']) - - # Print dataset length - print(f"Dataset length for {sensor['name']}: {len(dataset)}") - - # Retrieve and print the first item - item = dataset[0] - print(f"First item from {sensor['name']}:") - print(item) - - # Ensure the dataset item structure is correct - expected_keys = {"timestamp", "latitude", "longitude", "metadata"} - assert set(item.keys()) == expected_keys, f"Dataset item keys for {sensor['name']} are not as expected." - - # Validate tensor properties - assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor['name']}." - assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor['name']}." - assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor['name']}." - - assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor['name']}." - assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor['name']}." - assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor['name']}." - - assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor['name']}." - assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor['name']}." - assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor['name']}." - - assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor['name']}." 
- assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor['name']}" - - print(f"Metadata for {sensor['name']}: {item['metadata']}\n") \ No newline at end of file + assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch" + assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch" + assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch" + assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch" + assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch" + assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch" + assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch" + assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch" From 1b08799898f906b203ae351bca5ffe88d99499b3 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Sat, 22 Feb 2025 09:16:43 +0530 Subject: [PATCH 12/14] nnjai_wrapp removal --- graph_weather/data/nnjai_wrapp.py | 106 ------------------------------ 1 file changed, 106 deletions(-) delete mode 100644 graph_weather/data/nnjai_wrapp.py diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py deleted file mode 100644 index d66c9465..00000000 --- a/graph_weather/data/nnjai_wrapp.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS - -The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and -variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. -""" - -import numpy as np -import torch -from torch.utils.data import Dataset - -try: - from nnja import DataCatalog -except ImportError: - print( - "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`" - ) - - -class SensorDataset(Dataset): - """A custom PyTorch Dataset for handling various sensor data.""" - - def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): - """Initialize the dataset loader for various sensors. - - Args: - dataset_name: Name of the dataset to load. - time: Specific timestamp to filter the data. - primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). - additional_variables: List of additional variables to include in metadata. 
-            sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS)
-        """
-        self.dataset_name = dataset_name
-        self.time = time
-        self.primary_descriptors = primary_descriptors
-        self.additional_variables = additional_variables
-        self.sensor_type = sensor_type  # New argument for selecting sensor type
-
-        # Load data catalog and dataset
-        self.catalog = DataCatalog(skip_manifest=True)
-        self.dataset = self.catalog[self.dataset_name]
-        self.dataset.load_manifest()
-
-        if self.sensor_type == "AMSU":
-            self.dataset = self.dataset.sel(
-                time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)]
-            )
-        elif self.sensor_type == "ATMS":
-            self.dataset = self.dataset.sel(
-                time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)]
-            )
-        elif self.sensor_type == "MHS":
-            self.dataset = self.dataset.sel(
-                time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)]
-            )
-        elif self.sensor_type == "IASI":
-            self.dataset = self.dataset.sel(
-                time=self.time, variables=self.primary_descriptors + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)]
-            )
-        elif self.sensor_type == "CrIS":
-            self.dataset = self.dataset.sel(
-                time=self.time, variables=self.primary_descriptors + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)]
-            )
-        else:
-            raise ValueError(f"Unsupported sensor type: {self.sensor_type}")
-
-        self.dataframe = self.dataset.load_dataset(engine="pandas")
-
-        for col in primary_descriptors:
-            if col not in self.dataframe.columns:
-                raise ValueError(f"The dataset must include a '{col}' column.")
-
-        self.metadata_columns = [
-            col for col in self.dataframe.columns if col not in self.primary_descriptors
-        ]
-
-    def __len__(self):
-        """Return the total number of samples in the dataset."""
-        return len(self.dataframe)
-
-    def __getitem__(self, index):
-        """Return the observation and metadata for a given index."""
-        row = self.dataframe.iloc[index]
-        time = row["OBS_TIMESTAMP"].timestamp()
-        latitude = row["LAT"]
-        longitude = row["LON"]
-        metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)
-
-        return {
-            "timestamp": torch.tensor(time, dtype=torch.float32),
-            "latitude": torch.tensor(latitude, dtype=torch.float32),
-            "longitude": torch.tensor(longitude, dtype=torch.float32),
-            "metadata": torch.from_numpy(metadata),
-        }
-
-
-def collate_fn(batch):
-    """Custom collate function to handle batching of dictionary data.
-
-    Args:
-        batch: List of dictionaries from __getitem__
-
-    Returns:
-        Single dictionary with batched tensors
-    """
-    return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}

From 3c53605b6ca0a2f32d32f09ba1453ce8ab96bd50 Mon Sep 17 00:00:00 2001
From: Yuvraaj Narula
Date: Mon, 24 Feb 2025 16:05:44 +0530
Subject: [PATCH 13/14] removal of nnjai_wrapp.py

---
 graph_weather/data/nnjai_wrapp.py | 97 ------------------------------
 1 file changed, 97 deletions(-)
 delete mode 100644 graph_weather/data/nnjai_wrapp.py

diff --git a/graph_weather/data/nnjai_wrapp.py b/graph_weather/data/nnjai_wrapp.py
deleted file mode 100644
index 61c209d2..00000000
--- a/graph_weather/data/nnjai_wrapp.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""
-A custom PyTorch Dataset implementation for AMSU datasets.
-
-This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets.
-The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
-variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
-"""
-
-import numpy as np
-import torch
-from torch.utils.data import Dataset
-
-try:
-    from nnja import DataCatalog
-except ImportError:
-    print(
-        "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`"
-    )
-
-
-class AMSUDataset(Dataset):
-    """A custom PyTorch Dataset for handling AMSU data.
-
-    This dataset retrieves observations and their metadata, filtered by the provided time and
-    variable descriptors.
-    """
-
-    def __init__(self, dataset_name, time, primary_descriptors, additional_variables):
-        """Initialize the AMSU dataset loader.
-
-        Args:
-            dataset_name: Name of the dataset to load.
-            time: Specific timestamp to filter the data.
-            primary_descriptors: List of primary descriptor variables to include (e.g.,
-                OBS_TIMESTAMP, LAT, LON).
-            additional_variables: List of additional variables to include in metadata.
-        """
-        self.dataset_name = dataset_name
-        self.time = time
-        self.primary_descriptors = primary_descriptors
-        self.additional_variables = additional_variables
-
-        # Load data catalog and dataset
-        self.catalog = DataCatalog(skip_manifest=True)
-        self.dataset = self.catalog[self.dataset_name]
-        self.dataset.load_manifest()
-
-        self.dataset = self.dataset.sel(
-            time=self.time, variables=self.primary_descriptors + self.additional_variables
-        )
-        self.dataframe = self.dataset.load_dataset(engine="pandas")
-
-        for col in primary_descriptors:
-            if col not in self.dataframe.columns:
-                raise ValueError(f"The dataset must include a '{col}' column.")
-
-        self.metadata_columns = [
-            col for col in self.dataframe.columns if col not in self.primary_descriptors
-        ]
-
-    def __len__(self):
-        """Return the total number of samples in the dataset."""
-        return len(self.dataframe)
-
-    def __getitem__(self, index):
-        """Return the observation and metadata for a given index.
-
-        Args:
-            index: Index of the observation to retrieve.
-
-        Returns:
-            A dictionary containing timestamp, latitude, longitude, and metadata.
-        """
-        row = self.dataframe.iloc[index]
-        time = row["OBS_TIMESTAMP"].timestamp()
-        latitude = row["LAT"]
-        longitude = row["LON"]
-        metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)
-
-        return {
-            "timestamp": torch.tensor(time, dtype=torch.float32),
-            "latitude": torch.tensor(latitude, dtype=torch.float32),
-            "longitude": torch.tensor(longitude, dtype=torch.float32),
-            "metadata": torch.from_numpy(metadata),
-        }
-
-
-def collate_fn(batch):
-    """Custom collate function to handle batching of dictionary data.
- - Args: - batch: List of dictionaries from __getitem__ - - Returns: - Single dictionary with batched tensors - """ - return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} From fa297f39073297ad1498b199b6192d84e57e9fd7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Feb 2025 10:49:26 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/__init__.py | 2 +- graph_weather/data/nnja_ai.py | 23 +++++--- tests/test_nnjai.py | 102 ++++++++++++++++++++++------------ 3 files changed, 84 insertions(+), 43 deletions(-) diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 9cf9ddc5..fc0ea4fb 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -2,4 +2,4 @@ from .data.nnja_ai import SensorDataset, collate_fn from .models.analysis import GraphWeatherAssimilator -from .models.forecast import GraphWeatherForecaster \ No newline at end of file +from .models.forecast import GraphWeatherForecaster diff --git a/graph_weather/data/nnja_ai.py b/graph_weather/data/nnja_ai.py index d66c9465..457c59a2 100755 --- a/graph_weather/data/nnja_ai.py +++ b/graph_weather/data/nnja_ai.py @@ -1,5 +1,5 @@ """ -A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS +A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. @@ -20,7 +20,9 @@ class SensorDataset(Dataset): """A custom PyTorch Dataset for handling various sensor data.""" - def __init__(self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"): + def __init__( + self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU" + ): """Initialize the dataset loader for various sensors. 
Args: @@ -43,23 +45,30 @@ def __init__(self, dataset_name, time, primary_descriptors, additional_variables if self.sensor_type == "AMSU": self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)] + time=self.time, + variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)], ) elif self.sensor_type == "ATMS": self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)] + time=self.time, + variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)], ) elif self.sensor_type == "MHS": self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)] + time=self.time, + variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)], ) elif self.sensor_type == "IASI": self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)] + time=self.time, + variables=self.primary_descriptors + + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)], ) elif self.sensor_type == "CrIS": self.dataset = self.dataset.sel( - time=self.time, variables=self.primary_descriptors + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] + time=self.time, + variables=self.primary_descriptors + + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)], ) else: raise ValueError(f"Unsupported sensor type: {self.sensor_type}") diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index 959a32d7..252ce4b3 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -1,9 +1,9 @@ """ -Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior. +Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior. The tests ensure correct handling of data types, shapes, and batch processing for various sensor types. 
""" -from datetime import datetime +from datetime import datetime from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -12,6 +12,7 @@ from graph_weather.data.nnja_ai import SensorDataset, collate_fn + def get_sensor_variables(sensor_type): """Helper function to get the correct variables for each sensor type.""" if sensor_type == "AMSU": @@ -19,13 +20,14 @@ def get_sensor_variables(sensor_type): elif sensor_type == "ATMS": return [f"TMBR_000{i:02d}" for i in range(1, 23)] # 22 channels elif sensor_type == "MHS": - return [f"TMBR_000{i:02d}" for i in range(1, 6)] # 5 channels + return [f"TMBR_000{i:02d}" for i in range(1, 6)] # 5 channels elif sensor_type == "IASI": return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)] # 616 channels elif sensor_type == "CrIS": return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] # 431 channels return [] + @pytest.fixture def mock_datacatalog(): """ @@ -45,11 +47,15 @@ def create_mock_df(engine="pandas"): sensor_vars = get_sensor_variables(mock_dataset.sensor_type) # Create DataFrame with required columns - df = pd.DataFrame({ - "OBS_TIMESTAMP": pd.date_range(start=datetime(2021, 1, 1), periods=100, freq='H'), - "LAT": np.full(100, 45.0), - "LON": np.full(100, -120.0) - }) + df = pd.DataFrame( + { + "OBS_TIMESTAMP": pd.date_range( + start=datetime(2021, 1, 1), periods=100, freq="H" + ), + "LAT": np.full(100, 45.0), + "LON": np.full(100, -120.0), + } + ) # Add sensor-specific variables for var in sensor_vars: @@ -64,8 +70,7 @@ def create_mock_df(engine="pandas"): def get_mock_dataset(self, name): # Set the sensor type based on the requested dataset name mock_dataset.sensor_type = next( - config["sensor_type"] for config in SENSOR_CONFIGS - if config["name"] == name + config["sensor_type"] for config in SENSOR_CONFIGS if config["name"] == name ) return mock_dataset @@ -74,35 +79,37 @@ def get_mock_dataset(self, name): yield mock + # Test configurations SENSOR_CONFIGS = [ { "name": "amsu-1bamua-NC021023", "sensor_type": "AMSU", - "expected_metadata_size": 15 # 15 TMBR channels + "expected_metadata_size": 15, # 15 TMBR channels }, { "name": "atms-atms-NC021203", "sensor_type": "ATMS", - "expected_metadata_size": 22 # 22 TMBR channels + "expected_metadata_size": 22, # 22 TMBR channels }, { "name": "mhs-1bmhs-NC021027", "sensor_type": "MHS", - "expected_metadata_size": 5 # 5 TMBR channels + "expected_metadata_size": 5, # 5 TMBR channels }, { "name": "iasi-mtiasi-NC021241", "sensor_type": "IASI", - "expected_metadata_size": 616 # 616 SCRA channels + "expected_metadata_size": 616, # 616 SCRA channels }, { "name": "cris-crisf4-NC021206", "sensor_type": "CrIS", - "expected_metadata_size": 431 # 431 SRAD channels - } + "expected_metadata_size": 431, # 431 SRAD channels + }, ] + @pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS) def test_sensor_dataset(mock_datacatalog, sensor_config): """Test the SensorDataset class for different sensor types.""" @@ -114,7 +121,7 @@ def test_sensor_dataset(mock_datacatalog, sensor_config): time=time, primary_descriptors=primary_descriptors, additional_variables=get_sensor_variables(sensor_config["sensor_type"]), - sensor_type=sensor_config["sensor_type"] + sensor_type=sensor_config["sensor_type"], ) # Test dataset length @@ -123,25 +130,50 @@ def test_sensor_dataset(mock_datacatalog, sensor_config): # Test single item structure item = dataset[0] expected_keys = {"timestamp", "latitude", "longitude", "metadata"} - assert set(item.keys()) == expected_keys, f"Dataset item keys are not as 
expected for {sensor_config['sensor_type']}" + assert ( + set(item.keys()) == expected_keys + ), f"Dataset item keys are not as expected for {sensor_config['sensor_type']}" # Validate tensor properties - assert isinstance(item["timestamp"], torch.Tensor), f"Timestamp should be a tensor for {sensor_config['sensor_type']}" - assert item["timestamp"].dtype == torch.float32, f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}" - assert item["timestamp"].ndim == 0, f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}" - - assert isinstance(item["latitude"], torch.Tensor), f"Latitude should be a tensor for {sensor_config['sensor_type']}" - assert item["latitude"].dtype == torch.float32, f"Latitude should have dtype float32 for {sensor_config['sensor_type']}" - assert item["latitude"].ndim == 0, f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}" - - assert isinstance(item["longitude"], torch.Tensor), f"Longitude should be a tensor for {sensor_config['sensor_type']}" - assert item["longitude"].dtype == torch.float32, f"Longitude should have dtype float32 for {sensor_config['sensor_type']}" - assert item["longitude"].ndim == 0, f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}" - - assert isinstance(item["metadata"], torch.Tensor), f"Metadata should be a tensor for {sensor_config['sensor_type']}" - assert item["metadata"].shape == (sensor_config["expected_metadata_size"],), \ - f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)" - assert item["metadata"].dtype == torch.float32, f"Metadata should have dtype float32 for {sensor_config['sensor_type']}" + assert isinstance( + item["timestamp"], torch.Tensor + ), f"Timestamp should be a tensor for {sensor_config['sensor_type']}" + assert ( + item["timestamp"].dtype == torch.float32 + ), f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}" + assert ( + item["timestamp"].ndim == 0 + ), f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance( + item["latitude"], torch.Tensor + ), f"Latitude should be a tensor for {sensor_config['sensor_type']}" + assert ( + item["latitude"].dtype == torch.float32 + ), f"Latitude should have dtype float32 for {sensor_config['sensor_type']}" + assert ( + item["latitude"].ndim == 0 + ), f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance( + item["longitude"], torch.Tensor + ), f"Longitude should be a tensor for {sensor_config['sensor_type']}" + assert ( + item["longitude"].dtype == torch.float32 + ), f"Longitude should have dtype float32 for {sensor_config['sensor_type']}" + assert ( + item["longitude"].ndim == 0 + ), f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}" + + assert isinstance( + item["metadata"], torch.Tensor + ), f"Metadata should be a tensor for {sensor_config['sensor_type']}" + assert item["metadata"].shape == ( + sensor_config["expected_metadata_size"], + ), f"Metadata shape mismatch for {sensor_config['sensor_type']}. 
Expected ({sensor_config['expected_metadata_size']},)" + assert ( + item["metadata"].dtype == torch.float32 + ), f"Metadata should have dtype float32 for {sensor_config['sensor_type']}" def test_collate_function(): @@ -167,4 +199,4 @@ def test_collate_function(): assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch" assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch" assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch" - assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch" \ No newline at end of file + assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch"
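
A minimal usage sketch of the API these patches converge on, tying together the `SensorDataset` and `collate_fn` from graph_weather/data/nnja_ai.py. It assumes the NNJA-AI library is installed and authenticated and that the "amsu-1bamua-NC021023" catalog entry is reachable; the batch size of 4 is illustrative, not something fixed by the patch series.

from datetime import datetime

from torch.utils.data import DataLoader

from graph_weather.data.nnja_ai import SensorDataset, collate_fn

# Build an AMSU dataset for a single analysis time. additional_variables can
# stay empty because the AMSU branch of SensorDataset selects its 15 TMBR
# channels internally via dataset.sel(...).
dataset = SensorDataset(
    dataset_name="amsu-1bamua-NC021023",
    time=datetime(2021, 1, 1, 0, 0),
    primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
    additional_variables=[],
    sensor_type="AMSU",
)

# collate_fn stacks each key of the per-sample dicts into one batched tensor.
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

for batch in loader:
    # Scalars batch to shape (4,); AMSU metadata batches to shape (4, 15).
    print(batch["timestamp"].shape, batch["metadata"].shape)
    break

Returning a dict per sample keeps the collate logic generic: the same collate_fn batches every sensor type, because it simply stacks whatever keys __getitem__ emits.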