Skip to content

Commit

Permalink
optimise top boundary processing RAM usage
Browse files Browse the repository at this point in the history
  • Loading branch information
dongqi-DQ committed May 9, 2022
1 parent 832e302 commit d6fc6e2
Showing 1 changed file with 38 additions and 51 deletions.
89 changes: 38 additions & 51 deletions run_config_wrf4palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from multiprocess import Pool
from dynamic_util.nearest import framing_2d_cartesian
from dynamic_util.loc_dom import calc_stretch, domain_location, generate_cfg
from dynamic_util.process_wrf import zinterp, multi_zinterp, process_top
from dynamic_util.process_wrf import zinterp, multi_zinterp
from dynamic_util.geostrophic import calc_geostrophic_wind
from dynamic_util.surface_nan_solver import *
import warnings
Expand Down Expand Up @@ -346,20 +346,20 @@
#-------------------------------------------------------------------------------
print("Start vertical interpolation")
# create an empty dataset to store interpolated data
print("ds_we")
print("create empty datasets")
ds_we = ds_interp.isel(west_east=[0,-1])
ds_sn = ds_interp.isel(south_north=[0,-1])

print("ds_we_ustag")
print("create empty datasets for staggered U and V (west&east boundaries)")
ds_we_ustag = ds_interp_u.isel(west_east=[0,-1])
ds_we_vstag = ds_interp_v.isel(west_east=[0,-1])

print("ds_sn_ustag")
print("create empty datasets for staggered U and V (south&north boundaries)")
ds_sn_ustag = ds_interp_u.isel(south_north=[0,-1])
ds_sn_vstag = ds_interp_v.isel(south_north=[0,-1])

varbc_list = ["W", "QVAPOR","pt","Z"]
print("loop")
print("remove unused vars from datasets")
for var in ds_we.data_vars:
if var not in varbc_list:
ds_we = ds_we.drop(var)
Expand All @@ -371,30 +371,29 @@
ds_we_vstag = ds_we_vstag.drop(var)
ds_sn_vstag = ds_sn_vstag.drop(var)

print("load ds_we")
print("load dataset for west&east boundaries")
ds_we = ds_we.load()
print("load ds_sn")
print("load dataset for south&north boundaries")
ds_sn = ds_sn.load()

print("load ds_we_ustag")
print("load dataset for west&east boundaries (staggered U)")
ds_we_ustag = ds_we_ustag.load()
print("load ds_sn_ustag")
print("load dataset for south&north boundaries (staggered U)")
ds_sn_ustag = ds_sn_ustag.load()

print("load ds_we_vstag")
print("load dataset for west&east boundaries (staggered V)")
ds_we_vstag = ds_we_vstag.load()
print("load ds_sn_vstag")
print("load dataset for south&north boundaries (staggered V)")
ds_sn_vstag = ds_sn_vstag.load()

print("ds_palm_we")
print("create datasets to save data in PALM coordinates")
ds_palm_we = xr.Dataset()
ds_palm_we = ds_palm_we.assign_coords({"x": x[:2],"y": y, "time":ds_interp.time.data,
"z": z, "yv": yv, "xu": xu[:2], "zw":zw})
print("ds_palm_sn")
ds_palm_sn = xr.Dataset()
ds_palm_sn = ds_palm_sn.assign_coords({"x": x,"y": y[:2], "time":ds_interp.time.data,
"z": z, "yv": yv[:2], "xu": xu, "zw":zw})
print("zeros_we")
print("create zeros arrays for vertical interpolation")
zeros_we = np.zeros((len(all_ts), len(z), len(y), len(x[:2])))
zeros_sn = np.zeros((len(all_ts), len(z), len(y[:2]), len(x)))

Expand All @@ -413,30 +412,30 @@
ds_palm_we["W"] = xr.DataArray(np.copy(zeros_we_w), dims=['time','zw','y', 'x'])
ds_palm_sn["W"] = xr.DataArray(np.copy(zeros_sn_w), dims=['time','zw','y', 'x'])

print(f"Processing W for west and east boundaries")
print("Processing W for west and east boundaries")
ds_palm_we["W"] = multi_zinterp(max_pool, ds_we, "W", zw, ds_palm_we)
print(f"Processing W for south and north boundaries")
print("Processing W for south and north boundaries")
ds_palm_sn["W"] = multi_zinterp(max_pool, ds_sn, "W", zw, ds_palm_sn)

# interpolate u and v
zeros_we_u = np.zeros((len(all_ts), len(z), len(y), len(xu[:2])))
zeros_sn_u = np.zeros((len(all_ts), len(z), len(y[:2]), len(xu)))
ds_palm_we["U"] = xr.DataArray(np.copy(zeros_we_u), dims=['time','z','y', 'xu'])
print(f"Processing U for west and east boundaries")
print("Processing U for west and east boundaries")
ds_palm_we["U"] = multi_zinterp(max_pool, ds_we_ustag, "U", z, ds_palm_we)

ds_palm_sn["U"] = xr.DataArray(np.copy(zeros_sn_u), dims=['time','z','y', 'xu'])
print(f"Processing U for south and north boundaries")
print("Processing U for south and north boundaries")
ds_palm_sn["U"] = multi_zinterp(max_pool, ds_sn_ustag, "U", z, ds_palm_sn)

zeros_we_v = np.zeros((len(all_ts), len(z), len(yv), len(x[:2])))
zeros_sn_v = np.zeros((len(all_ts), len(z), len(yv[:2]), len(x)))
ds_palm_we["V"] = xr.DataArray(np.copy(zeros_we_v), dims=['time','z','yv', 'x'])
print(f"Processing V for west and east boundaries")
print("Processing V for west and east boundaries")
ds_palm_we["V"] = multi_zinterp(max_pool, ds_we_vstag, "V", z, ds_palm_we)

ds_palm_sn["V"] = xr.DataArray(np.copy(zeros_sn_v), dims=['time','z','yv', 'x'])
print(f"Processing V for south and north boundaries")
print("Processing V for south and north boundaries")
ds_palm_sn["V"] = multi_zinterp(max_pool, ds_sn_vstag, "V", z, ds_palm_sn)
#-------------------------------------------------------------------------------
# top boundary
Expand All @@ -447,44 +446,32 @@
w_top = np.zeros((len(all_ts), len(y), len(x)))
qv_top = np.zeros((len(all_ts), len(y), len(x)))
pt_top = np.zeros((len(all_ts), len(y), len(x)))
print("Processing top boundary conditions...loop")

for var in ds_interp.data_vars:
if var not in varbc_list:
ds_interp = ds_interp.drop(var)
if var not in ["U", "Z"]:
ds_interp_u = ds_interp_u.drop(var)
if var not in ["V", "Z"]:
ds_interp_v = ds_interp_v.drop(var)
print("Processing top boundary conditions...load")
print("Processing top boundary conditions...load.ds_interp")
ds_interp = ds_interp.load()
print("Processing top boundary conditions...load.ds_interp_u")
ds_interp_u = ds_interp_u.load()
print("Processing top boundary conditions...load.ds_interp_v")
ds_interp_v = ds_interp_v.load()


top_dict = {"U": (ds_interp_u, u_top, z),
"V": (ds_interp_v, v_top, z),
"pt": (ds_interp, pt_top, z),
"QVAPOR": (ds_interp, qv_top, z),
"W": (ds_interp, w_top, zw)}
print("Processing top boundary conditions...pool")
with Pool(max_pool) as p:
pool_outputs = list(
tqdm(
p.imap(partial(process_top, all_ts,top_dict), top_dict.keys()), total=len(top_dict.keys()),
position=0, leave=True
)
)
p.join()
## convert dictionary back to dataset
pool_dict = dict(pool_outputs)
u_top = pool_dict["U"]
v_top = pool_dict["V"]
w_top = pool_dict["W"]
qv_top = pool_dict["QVAPOR"]
pt_top = pool_dict["pt"]

print("Processing top boundary datasets...")
ds_interp_top = xr.Dataset()
ds_interp_u_top = xr.Dataset()
ds_interp_v_top = xr.Dataset()
for var in ["QVAPOR", "pt"]:
ds_interp_top[var] = ds_interp.salem.wrf_zlevel(var, levels=z[-1]).copy()

ds_interp_top["W"] = ds_interp.salem.wrf_zlevel("W", levels=zw[-1]).copy()
ds_interp_u_top["U"] = ds_interp_u.salem.wrf_zlevel("U", levels=z[-1]).copy()
ds_interp_v_top["V"] = ds_interp_v.salem.wrf_zlevel("V", levels=z[-1]).copy()

for ts in tqdm(range(0,len(all_ts)), total=len(all_ts), position=0, leave=True):
u_top[ts,:,:] = ds_interp_u_top["U"].isel(time=ts)
v_top[ts,:,:] = ds_interp_v_top["V"].isel(time=ts)
w_top[ts,:,:] = ds_interp_top["W"].isel(time=ts)
pt_top[ts,:,:] = ds_interp_top["pt"].isel(time=ts)
qv_top[ts,:,:] = ds_interp_top["QVAPOR"].isel(time=ts)
#-------------------------------------------------------------------------------
# Geostrophic wind estimation
#-------------------------------------------------------------------------------
Expand Down

1 comment on commit d6fc6e2

@dongqi-DQ
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The top boundary processing no longer loads the entire dataset into memory; only the data at the top boundary level are loaded. This can reduce RAM usage by around 40%.
I also reworded several print statements to make the progress messages clearer.

Please sign in to comment.