diff --git a/tests/test_accessors.py b/tests/test_accessors.py index a64db416..794c603f 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -79,6 +79,11 @@ def test_dataarray_destagger(test_grid): xr.testing.assert_allclose(destaggered['XLAT'], test_grid['XLAT']) xr.testing.assert_allclose(destaggered['XLONG'], test_grid['XLONG']) + # Check attributes are preserved + assert set(destaggered.attrs.keys()) == set(data.attrs.keys()) - { + 'stagger', + } + @pytest.mark.parametrize('test_grid', ['lambert_conformal', 'mercator'], indirect=True) def test_dataarray_destagger_with_exclude(test_grid): @@ -110,5 +115,12 @@ def test_dataset_destagger(test_grid): or destaggered[varname].attrs['stagger'] == '' ) + # Check preservation of variable attrs + for varname in set(test_grid.data_vars).intersection(set(destaggered.data_vars)): + # because of xwrf.postprocess, the destaggered attrs will include more information + assert set(test_grid[varname].attrs.keys()) - {'stagger', 'units'} <= set( + destaggered[varname].attrs.keys() + ) + # Check that attrs are preserved assert destaggered.attrs == test_grid.attrs diff --git a/tests/test_destagger.py b/tests/test_destagger.py index 695068f8..f79c968d 100644 --- a/tests/test_destagger.py +++ b/tests/test_destagger.py @@ -63,14 +63,16 @@ def test_destag_variable_multiple_dims(): ], ) def test_destag_variable_1d(unstag_dim_name, expected_output_dim_name): - staggered = xr.Variable(('bottom_top_stag',), np.arange(5), attrs={'stagger': 'Z'}) + staggered = xr.Variable( + ('bottom_top_stag',), np.arange(5), attrs={'foo': 'bar', 'stagger': 'Z'} + ) output = _destag_variable(staggered, unstag_dim_name=unstag_dim_name) # Check values np.testing.assert_array_almost_equal(output.values, 0.5 + np.arange(4)) # Check dim name assert output.dims[0] == expected_output_dim_name # Check attrs - assert not output.attrs + assert output.attrs == {'foo': 'bar'} def test_destag_variable_2d(): diff --git a/xwrf/destagger.py b/xwrf/destagger.py index 9752db87..b8f539c0 100644 --- a/xwrf/destagger.py +++ b/xwrf/destagger.py @@ -73,8 +73,8 @@ def _destag_variable(datavar, stagger_dim=None, unstag_dim_name=None): return xr.Variable( dims=tuple(str(unstag_dim_name) if dim == stagger_dim else dim for dim in center_mean.dims), data=center_mean.data, - attrs=_drop_attrs(center_mean.attrs, ('stagger',)), - encoding=center_mean.encoding, + attrs=_drop_attrs(datavar.attrs, ('stagger',)), + encoding=datavar.encoding, fastpath=True, )