
Commit 2245c9a
Merge pull request #1546 from mraspaud/feature-compact-viirs-distributed
Make viirs-compact datasets compatible with dask distributed
mraspaud authored Feb 16, 2021
2 parents 50ddcc6 + c1171a3 commit 2245c9a
Showing 2 changed files with 152 additions and 141 deletions.
237 changes: 114 additions & 123 deletions satpy/readers/viirs_compact.py
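
The central technique of this change: the numpy-domain helpers used inside `da.map_blocks` move from methods of the file handler to module-level functions. A bound method forces dask to serialize `self`, and the handler owns an open `h5py.File`, which cannot be pickled to dask.distributed workers; a module-level function is pickled by reference, so only its plain array and scalar arguments travel with the graph. A minimal sketch of the pattern, not the reader's actual code (names hypothetical; `threading.Lock` stands in for the unpicklable file handle):

import threading

import dask.array as da


class Handler:
    """Stand-in for a file handler that owns an unpicklable resource."""

    def __init__(self):
        self.resource = threading.Lock()  # like an open h5py.File: cannot be pickled
        self.scans = 10

    def times_scans(self, block):
        # A bound method drags `self` (and the lock) into the task graph,
        # so dask.distributed fails when it tries to ship this to a worker.
        return block * self.scans


def times_scans_free(block, scans):
    # A module-level function is pickled by reference; only the plain int
    # `scans` travels with the graph.
    return block * scans


handler = Handler()
arr = da.ones((10, 10), chunks=5)

bad = da.map_blocks(handler.times_scans, arr, dtype=float)  # breaks on a cluster
good = da.map_blocks(times_scans_free, arr, scans=handler.scans, dtype=float)
good.compute()  # fine under both the local scheduler and distributed

This is why `expand` and `get_coefs` below gain explicit `scans`, `scan_size` and `scan_offset` parameters instead of reading them from `self`.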
@@ -30,35 +30,35 @@
"""

import logging
+from contextlib import suppress
from datetime import datetime, timedelta

+import dask.array as da
import h5py
import numpy as np
import xarray as xr
-import dask.array as da

+from satpy import CHUNK_SIZE
from satpy.readers.file_handlers import BaseFileHandler
from satpy.readers.utils import np2str
from satpy.utils import angle2xyz, lonlat2xyz, xyz2angle, xyz2lonlat
-from satpy import CHUNK_SIZE

-chans_dict = {"M01": "M1",
-              "M02": "M2",
-              "M03": "M3",
-              "M04": "M4",
-              "M05": "M5",
-              "M06": "M6",
-              "M07": "M7",
-              "M08": "M8",
-              "M09": "M9",
-              "M10": "M10",
-              "M11": "M11",
-              "M12": "M12",
-              "M13": "M13",
-              "M14": "M14",
-              "M15": "M15",
-              "M16": "M16",
-              "DNB": "DNB"}
+_channels_dict = {"M01": "M1",
+                  "M02": "M2",
+                  "M03": "M3",
+                  "M04": "M4",
+                  "M05": "M5",
+                  "M06": "M6",
+                  "M07": "M7",
+                  "M08": "M8",
+                  "M09": "M9",
+                  "M10": "M10",
+                  "M11": "M11",
+                  "M12": "M12",
+                  "M13": "M13",
+                  "M14": "M14",
+                  "M15": "M15",
+                  "M16": "M16",
+                  "DNB": "DNB"}

logger = logging.getLogger(__name__)

@@ -99,40 +99,35 @@ def __init__(self, filename, filename_info, filetype_info):
or (max(abs(self.min_lat), abs(self.max_lat)) > 60))

self.scans = self.h5f["All_Data"]["NumberOfScans"][0]
-        self.geostuff = self.h5f["All_Data"]['VIIRS-%s-GEO_All' % self.ch_type]
+        self.geography = self.h5f["All_Data"]['VIIRS-%s-GEO_All' % self.ch_type]

for key in self.h5f["All_Data"].keys():
if key.startswith("VIIRS") and key.endswith("SDR_All"):
channel = key.split('-')[1]
break

-        # FIXME: this supposes there is only one tiepoint zone in the
-        # track direction
-        self.scan_size = self.h5f["All_Data/VIIRS-%s-SDR_All" %
-                                  channel].attrs["TiePointZoneSizeTrack"].item()
-        self.track_offset = self.h5f["All_Data/VIIRS-%s-SDR_All" %
-                                     channel].attrs["PixelOffsetTrack"]
-        self.scan_offset = self.h5f["All_Data/VIIRS-%s-SDR_All" %
-                                    channel].attrs["PixelOffsetScan"]
+        # This supposes there is only one tiepoint zone in the track direction.
+        channel_path = f"All_Data/VIIRS-{channel}-SDR_All"
+        self.scan_size = self.h5f[channel_path].attrs["TiePointZoneSizeTrack"].item()
+        self.track_offset = self.h5f[channel_path].attrs["PixelOffsetTrack"][()]
+        self.scan_offset = self.h5f[channel_path].attrs["PixelOffsetScan"][()]

try:
-            self.group_locations = self.geostuff[
-                "TiePointZoneGroupLocationScanCompact"][()]
+            self.group_locations = self.geography["TiePointZoneGroupLocationScanCompact"][()]
except KeyError:
self.group_locations = [0]

-        self.tpz_sizes = da.from_array(self.h5f["All_Data/VIIRS-%s-SDR_All" % channel].attrs["TiePointZoneSizeScan"],
-                                       chunks=1)
+        self.tpz_sizes = da.from_array(self.h5f[channel_path].attrs["TiePointZoneSizeScan"], chunks=1)
if len(self.tpz_sizes.shape) == 2:
if self.tpz_sizes.shape[1] != 1:
raise NotImplementedError("Can't handle 2 dimensional tiepoint zones.")
self.tpz_sizes = self.tpz_sizes.squeeze(1)
-        self.nb_tpzs = self.geostuff["NumberOfTiePointZonesScan"]
-        self.c_align = da.from_array(self.geostuff["AlignmentCoefficient"],
-                                     chunks=tuple(self.nb_tpzs))
-        self.c_exp = da.from_array(self.geostuff["ExpansionCoefficient"],
-                                   chunks=tuple(self.nb_tpzs))
-        self.nb_tpzs = da.from_array(self.nb_tpzs, chunks=1)
+        self.nb_tiepoint_zones = self.geography["NumberOfTiePointZonesScan"][()]
+        self.c_align = da.from_array(self.geography["AlignmentCoefficient"],
+                                     chunks=tuple(self.nb_tiepoint_zones))
+        self.c_exp = da.from_array(self.geography["ExpansionCoefficient"],
+                                   chunks=tuple(self.nb_tiepoint_zones))
+        self.nb_tiepoint_zones = da.from_array(self.nb_tiepoint_zones, chunks=1)
+        self._expansion_coefs = None

self.cache = {}
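
The `[()]` reads added above (`PixelOffsetTrack`, `PixelOffsetScan`, `NumberOfTiePointZonesScan`, `TiePointZoneGroupLocationScanCompact`) serve the same goal: indexing an h5py dataset with `()` copies it into a plain numpy array, so the value no longer references the open file. A small sketch, with a hypothetical file name and dataset path:

import h5py

# File name and dataset path are hypothetical, for illustration only.
with h5py.File("granule.h5", "r") as h5f:
    # dataset[()] reads the whole dataset into an in-memory numpy array.
    zones = h5f["All_Data/VIIRS-DNB-GEO_All/NumberOfTiePointZonesScan"][()]

# `zones` stays valid after the file is closed and pickles cleanly.
print(zones.shape, zones.dtype)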
@@ -144,15 +139,13 @@ def __init__(self, filename, filename_info, filetype_info):

def __del__(self):
"""Close file handlers when we are done."""
-        try:
+        with suppress(OSError):
            self.h5f.close()
-        except OSError:
-            pass

def get_dataset(self, key, info):
"""Load a dataset."""
logger.debug('Reading %s.', key['name'])
-        if key['name'] in chans_dict:
+        if key['name'] in _channels_dict:
m_data = self.read_dataset(key, info)
else:
m_data = self.read_geo(key, info)
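
The `__del__` cleanup above also swaps the try/except/pass idiom for `contextlib.suppress`, which is equivalent and reads better; a tiny standalone example:

import os
from contextlib import suppress

# Equivalent to: try: os.remove(...) / except OSError: pass
with suppress(OSError):
    os.remove("file-that-may-not-exist.tmp")  # FileNotFoundError is an OSError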
@@ -164,10 +157,8 @@ def get_bounding_box(self):
"""Get the bounding box of the data."""
for key in self.h5f["Data_Products"].keys():
if key.startswith("VIIRS") and key.endswith("GEO"):
-                lats = self.h5f["Data_Products"][key][
-                    key + '_Gran_0'].attrs['G-Ring_Latitude']
-                lons = self.h5f["Data_Products"][key][
-                    key + '_Gran_0'].attrs['G-Ring_Longitude']
+                lats = self.h5f["Data_Products"][key][key + '_Gran_0'].attrs['G-Ring_Latitude'][()]
+                lons = self.h5f["Data_Products"][key][key + '_Gran_0'].attrs['G-Ring_Longitude'][()]
break
else:
raise KeyError('Cannot find bounding coordinates!')
@@ -214,8 +205,6 @@ def read_geo(self, key, info):
attrs=self.mda, dims=('y', 'x'))

if info.get('standard_name') in ['latitude', 'longitude']:
-            if self.lons is None or self.lats is None:
-                self.lons, self.lats = self.navigate()
mda = self.mda.copy()
mda.update(info)
if info['standard_name'] == 'longitude':
@@ -226,13 +215,13 @@
if key['name'] == 'dnb_moon_illumination_fraction':
mda = self.mda.copy()
mda.update(info)
-            return xr.DataArray(da.from_array(self.geostuff["MoonIllumFraction"]),
+            return xr.DataArray(da.from_array(self.geography["MoonIllumFraction"]),
attrs=info)

def read_dataset(self, dataset_key, info):
"""Read a dataset."""
h5f = self.h5f
-        channel = chans_dict[dataset_key['name']]
+        channel = _channels_dict[dataset_key['name']]
chan_dict = dict([(key.split("-")[1], key)
for key in h5f["All_Data"].keys()
if key.startswith("VIIRS")])
@@ -245,13 +234,6 @@
h5attrs = h5rads.attrs
scans = h5f["All_Data"]["NumberOfScans"][0]
rads = rads[:scans * 16, :]
-        # if channel in ("M9", ):
-        #     arr = rads[:scans * 16, :].astype(np.float32)
-        #     arr[arr > 65526] = np.nan
-        #     arr = np.ma.masked_array(arr, mask=arr_mask)
-        # else:
-        #     arr = np.ma.masked_greater(rads[:scans * 16, :].astype(np.float32),
-        #                                65526)
rads = rads.where(rads <= 65526)
try:
rads = xr.where(rads <= h5attrs['Threshold'],
@@ -299,79 +281,38 @@
rads.attrs['units'] = unit
return rads

-    def expand(self, data, coefs):
-        """Perform the expansion in numpy domain."""
-        data = data.reshape(data.shape[:-1])
-
-        coefs = coefs.reshape(self.scans, self.scan_size, data.shape[1] - 1, -1, 4)
-
-        coef_a = coefs[:, :, :, :, 0]
-        coef_b = coefs[:, :, :, :, 1]
-        coef_c = coefs[:, :, :, :, 2]
-        coef_d = coefs[:, :, :, :, 3]
-
-        data_a = data[:self.scans * 2:2, np.newaxis, :-1, np.newaxis]
-        data_b = data[:self.scans * 2:2, np.newaxis, 1:, np.newaxis]
-        data_c = data[1:self.scans * 2:2, np.newaxis, 1:, np.newaxis]
-        data_d = data[1:self.scans * 2:2, np.newaxis, :-1, np.newaxis]
-
-        fdata = (coef_a * data_a + coef_b * data_b + coef_d * data_d + coef_c * data_c)
-
-        return fdata.reshape(self.scans * self.scan_size, -1)

def expand_angle_and_nav(self, arrays):
"""Expand angle and navigation datasets."""
res = []
for array in arrays:
-            res.append(da.map_blocks(self.expand, array[:, :, np.newaxis], self.expansion_coefs,
+            res.append(da.map_blocks(expand, array[:, :, np.newaxis], self.expansion_coefs,
+                                     scans=self.scans, scan_size=self.scan_size,
dtype=array.dtype, drop_axis=2, chunks=self.expansion_coefs.chunks[:-1]))
return res

-    def get_coefs(self, c_align, c_exp, tpz_size, nb_tpz, v_track):
-        """Compute the coeffs in numpy domain."""
-        nties = nb_tpz.item()
-        tpz_size = tpz_size.item()
-        v_scan = (np.arange(nties * tpz_size) % tpz_size + self.scan_offset) / tpz_size
-        s_scan, s_track = np.meshgrid(v_scan, v_track)
-        s_track = s_track.reshape(self.scans, self.scan_size, nties, tpz_size)
-        s_scan = s_scan.reshape(self.scans, self.scan_size, nties, tpz_size)
-
-        c_align = c_align[np.newaxis, np.newaxis, :, np.newaxis]
-        c_exp = c_exp[np.newaxis, np.newaxis, :, np.newaxis]
-
-        a_scan = s_scan + s_scan * (1 - s_scan) * c_exp + s_track * (
-            1 - s_track) * c_align
-        a_track = s_track
-        coef_a = (1 - a_track) * (1 - a_scan)
-        coef_b = (1 - a_track) * a_scan
-        coef_d = a_track * (1 - a_scan)
-        coef_c = a_track * a_scan
-        res = np.stack([coef_a, coef_b, coef_c, coef_d], axis=4).reshape(self.scans * self.scan_size, -1, 4)
-        return res

@property
def expansion_coefs(self):
"""Compute the expansion coefficients."""
if self._expansion_coefs is not None:
return self._expansion_coefs
v_track = (np.arange(self.scans * self.scan_size) % self.scan_size + self.track_offset) / self.scan_size
self.tpz_sizes = self.tpz_sizes.persist()
-        self.nb_tpzs = self.nb_tpzs.persist()
-        col_chunks = (self.tpz_sizes * self.nb_tpzs).compute()
-        self._expansion_coefs = da.map_blocks(self.get_coefs, self.c_align, self.c_exp, self.tpz_sizes, self.nb_tpzs,
-                                              dtype=np.float64, v_track=v_track, new_axis=[0, 2],
-                                              chunks=(self.scans * self.scan_size,
-                                                      tuple(col_chunks), 4))
+        self.nb_tiepoint_zones = self.nb_tiepoint_zones.persist()
+        col_chunks = (self.tpz_sizes * self.nb_tiepoint_zones).compute()
+        self._expansion_coefs = da.map_blocks(get_coefs, self.c_align, self.c_exp, self.tpz_sizes,
+                                              self.nb_tiepoint_zones,
+                                              v_track=v_track, scans=self.scans, scan_size=self.scan_size,
+                                              scan_offset=self.scan_offset,
+                                              dtype=np.float64, new_axis=[0, 2],
+                                              chunks=(self.scans * self.scan_size, tuple(col_chunks), 4))

return self._expansion_coefs

def navigate(self):
"""Generate the navigation datasets."""
-        shape = self.geostuff['Longitude'].shape
-        hchunks = (self.nb_tpzs + 1).compute()
-        chunks = (shape[0], tuple(hchunks))
-        lon = da.from_array(self.geostuff["Longitude"], chunks=chunks)
-        lat = da.from_array(self.geostuff["Latitude"], chunks=chunks)
+        chunks = self._get_geographical_chunks()
+        lon = da.from_array(self.geography["Longitude"], chunks=chunks)
+        lat = da.from_array(self.geography["Latitude"], chunks=chunks)
if self.switch_to_cart:
arrays = lonlat2xyz(lon, lat)
else:
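
A note on the `expansion_coefs` property above: when `da.map_blocks` changes block shapes, it must be told the output chunk sizes up front, so the tiny per-zone arrays are persisted and the chunk widths computed eagerly before the graph is built. A sketch with made-up zone sizes and counts:

import dask.array as da
import numpy as np

# Assumed values: per-block tie-point zone widths and zone counts.
tpz_sizes = da.from_array(np.array([16, 16, 16]), chunks=1).persist()
nb_zones = da.from_array(np.array([2, 3, 2]), chunks=1).persist()

# Per-block output widths, passed to map_blocks via its `chunks=` argument.
col_chunks = (tpz_sizes * nb_zones).compute()  # -> array([32, 48, 32])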
@@ -383,14 +324,18 @@

return expanded

+    def _get_geographical_chunks(self):
+        shape = self.geography['Longitude'].shape
+        horizontal_chunks = (self.nb_tiepoint_zones + 1).compute()
+        chunks = (shape[0], tuple(horizontal_chunks))
+        return chunks

def angles(self, azi_name, zen_name):
"""Generate the angle datasets."""
-        shape = self.geostuff['Longitude'].shape
-        hchunks = (self.nb_tpzs + 1).compute()
-        chunks = (shape[0], tuple(hchunks))
+        chunks = self._get_geographical_chunks()

-        azi = self.geostuff[azi_name]
-        zen = self.geostuff[zen_name]
+        azi = self.geography[azi_name]
+        zen = self.geography[zen_name]

switch_to_cart = ((np.max(azi) - np.min(azi) > 5)
or (np.min(zen) < 10)
@@ -433,6 +378,56 @@
return azi, zen


+def get_coefs(c_align, c_exp, tpz_size, nb_tpz, v_track, scans, scan_size, scan_offset):
+    """Compute the coeffs in numpy domain."""
+    nties = nb_tpz.item()
+    tpz_size = tpz_size.item()
+    v_scan = (np.arange(nties * tpz_size) % tpz_size + scan_offset) / tpz_size
+    s_scan, s_track = np.meshgrid(v_scan, v_track)
+    s_track = s_track.reshape(scans, scan_size, nties, tpz_size)
+    s_scan = s_scan.reshape(scans, scan_size, nties, tpz_size)
+
+    c_align = c_align[np.newaxis, np.newaxis, :, np.newaxis]
+    c_exp = c_exp[np.newaxis, np.newaxis, :, np.newaxis]
+
+    a_scan = s_scan + s_scan * (1 - s_scan) * c_exp + s_track * (
+        1 - s_track) * c_align
+    a_track = s_track
+    coef_a = (1 - a_track) * (1 - a_scan)
+    coef_b = (1 - a_track) * a_scan
+    coef_d = a_track * (1 - a_scan)
+    coef_c = a_track * a_scan
+    res = np.stack([coef_a, coef_b, coef_c, coef_d], axis=4).reshape(scans * scan_size, -1, 4)
+    return res


+def expand(data, coefs, scans, scan_size):
+    """Perform the expansion in numpy domain."""
+    data = data.reshape(data.shape[:-1])
+
+    coefs = coefs.reshape(scans, scan_size, data.shape[1] - 1, -1, 4)
+
+    coef_a = coefs[:, :, :, :, 0]
+    coef_b = coefs[:, :, :, :, 1]
+    coef_c = coefs[:, :, :, :, 2]
+    coef_d = coefs[:, :, :, :, 3]
+
+    corner_coefficients = (coef_a, coef_b, coef_c, coef_d)
+    fdata = _interpolate_data(data, corner_coefficients, scans)
+    return fdata.reshape(scans * scan_size, -1)


+def _interpolate_data(data, corner_coefficients, scans):
+    """Interpolate the data using the provided coefficients."""
+    coef_a, coef_b, coef_c, coef_d = corner_coefficients
+    data_a = data[:scans * 2:2, np.newaxis, :-1, np.newaxis]
+    data_b = data[:scans * 2:2, np.newaxis, 1:, np.newaxis]
+    data_c = data[1:scans * 2:2, np.newaxis, 1:, np.newaxis]
+    data_d = data[1:scans * 2:2, np.newaxis, :-1, np.newaxis]
+    fdata = (coef_a * data_a + coef_b * data_b + coef_d * data_d + coef_c * data_c)
+    return fdata


def expand_arrays(arrays,
scans,
c_align,
@@ -460,12 +455,8 @@
coef_b = (1 - a_track) * a_scan
coef_d = a_track * (1 - a_scan)
coef_c = a_track * a_scan
+    corner_coefficients = (coef_a, coef_b, coef_c, coef_d)
for data in arrays:
-        data_a = data[:scans * 2:2, np.newaxis, :-1, np.newaxis]
-        data_b = data[:scans * 2:2, np.newaxis, 1:, np.newaxis]
-        data_c = data[1:scans * 2:2, np.newaxis, 1:, np.newaxis]
-        data_d = data[1:scans * 2:2, np.newaxis, :-1, np.newaxis]
-        fdata = (coef_a * data_a + coef_b * data_b
-                 + coef_d * data_d + coef_c * data_c)
+        fdata = _interpolate_data(data, corner_coefficients, scans)
expanded.append(fdata.reshape(scans * scan_size, nties * tpz_size))
return expanded
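
With the helpers at module level, the graphs the reader builds should contain only picklable objects, so a granule can be loaded through a dask.distributed cluster end to end. A hedged sketch of that workflow (the granule file name is illustrative):

from dask.distributed import Client

from satpy import Scene

client = Client()  # local cluster; workers run in separate processes

# File name is illustrative; any compact VIIRS SDR granule should do.
scn = Scene(filenames=["SVDNBC_j01_d20210216_t0000000_e0001000_b00001.h5"],
            reader="viirs_compact")
scn.load(["DNB"])
dnb = scn["DNB"].compute()  # the graph serializes to the workers without error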