Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved mask sampler #2890

Merged
merged 9 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Improve mask sampler by adding an MPI step and an LS_chunk (intermediate step)

### Fixed

### Removed
Expand Down
8 changes: 4 additions & 4 deletions gridcomps/History/MAPL_HistoryGridComp.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2426,8 +2426,10 @@ subroutine Initialize ( gc, import, dumexport, clock, rc )
call list(n)%trajectory%initialize(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
IntState%stampoffset(n) = list(n)%trajectory%epoch_frequency
elseif (list(n)%sampler_spec == 'mask') then
call MAPL_TimerOn(GENSTATE,"mask_init")
list(n)%mask_sampler = MaskSamplerGeosat(cfg,string,clock,genstate=GENSTATE,_RC)
call list(n)%mask_sampler%initialize(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
call MAPL_TimerOff(GENSTATE,"mask_init")
elseif (list(n)%sampler_spec == 'station') then
list(n)%station_sampler = StationSampler (list(n)%bundle, trim(list(n)%stationIdFile), nskip_line=list(n)%stationSkipLine, genstate=GENSTATE, _RC)
call list(n)%station_sampler%add_metadata_route_handle(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
Expand Down Expand Up @@ -3706,11 +3708,9 @@ subroutine Run ( gc, import, export, clock, rc )
call MAPL_TimerOff(GENSTATE,"Station")
elseif (list(n)%sampler_spec == 'mask') then
call ESMF_ClockGet(clock,currTime=current_time,_RC)
call MAPL_TimerOn(GENSTATE,"Mask")
call MAPL_TimerOn(GENSTATE,"AppendFile")
call MAPL_TimerOn(GENSTATE,"Mask_append")
call list(n)%mask_sampler%append_file(current_time,_RC)
call MAPL_TimerOff(GENSTATE,"AppendFile")
call MAPL_TimerOff(GENSTATE,"Mask")
call MAPL_TimerOff(GENSTATE,"Mask_append")
endif


Expand Down
1 change: 1 addition & 0 deletions gridcomps/History/Sampler/MAPL_GeosatMaskMod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module MaskSamplerGeosatMod
use pFIO_FileMetadataMod, only : FileMetadata
use pFIO_NetCDF4_FileFormatterMod, only : NetCDF4_FileFormatter
use MAPL_GenericMod, only : MAPL_MetaComp, MAPL_TimerOn, MAPL_TimerOff
use MPI, only : MPI_INTEGER, MPI_REAL, MPI_REAL8
use, intrinsic :: iso_fortran_env, only: REAL32
use, intrinsic :: iso_fortran_env, only: REAL64
use pflogger, only: Logger, logging
Expand Down
201 changes: 138 additions & 63 deletions gridcomps/History/Sampler/MAPL_GeosatMaskMod_smod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module function MaskSamplerGeosat_from_config(config,string,clock,GENSTATE,rc) r
mask%clock=clock
mask%grid_file_name=''
if (present(GENSTATE)) mask%GENSTATE => GENSTATE

call ESMF_ClockGet ( clock, CurrTime=currTime, _RC )
if (mapl_am_I_root()) write(6,*) 'string', string

Expand Down Expand Up @@ -159,13 +159,13 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
integer, optional, intent(out) :: rc

type(Logger), pointer :: lgr
real(ESMF_KIND_R8), pointer :: ptAT(:)
type(ESMF_routehandle) :: RH
type(ESMF_Grid) :: grid
integer :: mypet, npes
integer :: mypet, petcount, mpic
integer :: iroot, rootpet, ierr
type (ESMF_LocStream) :: LS_rt
type (ESMF_LocStream) :: LS_ds
type (ESMF_LocStream) :: LS_chunk
type (LocStreamFactory):: locstream_factory
type (ESMF_Field) :: fieldA
type (ESMF_Field) :: fieldB
Expand All @@ -182,13 +182,11 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
type(ESMF_DElayout) :: layout
type(ESMF_VM) :: VM
integer :: myid
integer :: ndes
integer :: dimCount
integer, allocatable :: II(:)
integer, allocatable :: JJ(:)
real(REAL64), allocatable :: obs_lons(:)
real(REAL64), allocatable :: obs_lats(:)
integer :: mpic

type (ESMF_Field) :: fieldI4
type(ESMF_routehandle) :: RH_halo
Expand Down Expand Up @@ -227,17 +225,34 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
integer :: nsend
integer, allocatable :: recvcounts_loc(:)
integer, allocatable :: displs_loc(:)
integer :: status

integer, allocatable :: sendcount(:), displs(:)
integer :: recvcount
integer :: M, N, ip
integer :: nx2

real(REAL64), allocatable :: lons_chunk(:)
real(REAL64), allocatable :: lats_chunk(:)

integer :: status, imethod


lgr => logging%get_logger('HISTORY.sampler')

! Metacode:
! read ABI grid into LS_rt
! gen LS_ds with CS background grid
! read ABI grid into lons/lats, lons_chunk/lats_chunk
! gen LS_chunk and LS_ds with CS background grid
! find mask points on each PET with halo
! prepare recvcounts + displs for gatherv
!

call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=petcount, localpet=mypet, _RC)
iroot = 0
ip = mypet ! 0 to M-1
M = petCount

call MAPL_TimerOn(this%GENSTATE,"1_genABIgrid")
if (mapl_am_i_root()) then
! __s1. SAT file
!
Expand All @@ -247,100 +262,156 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
key_p = this%var_name_proj
key_p_att = this%att_name_proj
call get_ncfile_dimension(fn,nlon=n1,nlat=n2,key_lon=key_x,key_lat=key_y,_RC)
!
! use thin_factor to reduce regridding matrix size
!
xdim_true = n1
ydim_true = n2
xdim_red = n1 / this%thin_factor
ydim_red = n2 / this%thin_factor
allocate (x (xdim_true), _STAT )
allocate (y (xdim_true), _STAT )

allocate (x(n1), y(n2), _STAT)
call get_v1d_netcdf_R8_complete (fn, key_x, x, _RC)
call get_v1d_netcdf_R8_complete (fn, key_y, y, _RC)
call get_att_real_netcdf (fn, key_p, key_p_att, lambda0_deg, _RC)
lam_sat = lambda0_deg * MAPL_DEGREES_TO_RADIANS_R8
end if
call MAPL_CommsBcast(vm, DATA=n1, N=1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=n2, N=1, ROOT=MAPL_Root, _RC)
if ( .NOT. mapl_am_i_root() ) allocate (x(n1), y(n2), _STAT)
call MAPL_CommsBcast(vm, DATA=lam_sat, N=1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=x, N=n1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=y, N=n2, ROOT=MAPL_Root, _RC)

!
! use thin_factor to reduce regridding matrix size
!
xdim_red = n1 / this%thin_factor
ydim_red = n2 / this%thin_factor
_ASSERT ( xdim_red * ydim_red > M, 'mask reduced points after thin_factor is less than Nproc!')

nx=0
do i=1, xdim_red
do j=1, ydim_red
! get nx2
nx2=0
k=0
do i=1, xdim_red
do j=1, ydim_red
k = k + 1
if ( mod(k,M) == ip ) then
x0 = x( i * this%thin_factor )
y0 = y( j * this%thin_factor )
call ABI_XY_2_lonlat (x0, y0, lam_sat, lon0, lat0, mask=mask0)
if (mask0 > 0) then
nx=nx+1
nx2=nx2+1
end if
end do
end if
end do
allocate (lons(nx), lats(nx), _STAT)
nx = 0
do i=1, xdim_red
do j=1, ydim_red
end do
allocate (lons_chunk(nx2), lats_chunk(nx2), _STAT)

! get lons_chunk/...
nx2 = 0
k = 0
do i=1, xdim_red
do j=1, ydim_red
k = k + 1
if ( mod(k,M) == ip ) then
x0 = x( i * this%thin_factor )
y0 = y( j * this%thin_factor )
call ABI_XY_2_lonlat (x0, y0, lam_sat, lon0, lat0, mask=mask0)
if (mask0 > 0) then
nx=nx+1
lons(nx) = lon0 * MAPL_RADIANS_TO_DEGREES
lats(nx) = lat0 * MAPL_RADIANS_TO_DEGREES
nx2=nx2+1
lons_chunk(nx2) = lon0 * MAPL_RADIANS_TO_DEGREES
lats_chunk(nx2) = lat0 * MAPL_RADIANS_TO_DEGREES
end if
end do
end if
end do
arr(1)=nx
else
allocate(lons(0),lats(0),_STAT)
arr(1)=0
endif
end do

call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=npes, localpet=mypet, _RC)
arr(1)=nx2
call ESMF_VMAllFullReduce(vm, sendData=arr, recvData=nx, &
count=1, reduceflag=ESMF_REDUCE_SUM, _RC)
this%nobs = nx
if (mapl_am_I_root()) write(6,*) 'nobs tot :', nx

if ( nx == 0 ) then
this%is_valid = .false.
_RETURN(ESMF_SUCCESS)
!
! no valid obs points are found
!

! gatherV for lons/lats
if (mapl_am_i_root()) then
allocate(lons(nx),lats(nx),_STAT)
else
allocate(lons(0),lats(0),_STAT)
endif

allocate( this%recvcounts(petcount), this%displs(petcount), _STAT )
allocate( recvcounts_loc(petcount), displs_loc(petcount), _STAT )
recvcounts_loc(:)=1
displs_loc(1)=0
do i=2, petcount
displs_loc(i) = displs_loc(i-1) + recvcounts_loc(i-1)
end do
call MPI_gatherv ( nx2, 1, MPI_INTEGER, &
this%recvcounts, recvcounts_loc, displs_loc, MPI_INTEGER,&
iroot, mpic, ierr )
if (.not. mapl_am_i_root()) then
this%recvcounts(:) = 0
end if
this%displs(1)=0
do i=2, petcount
this%displs(i) = this%displs(i-1) + this%recvcounts(i-1)
end do

nsend = nx2
call MPI_gatherv ( lons_chunk, nsend, MPI_REAL8, &
lons, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )
call MPI_gatherv ( lats_chunk, nsend, MPI_REAL8, &
lats, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )


!! if (mapl_am_I_root()) write(6,*) 'nobs tot :', nx

deallocate (this%recvcounts, this%displs, _STAT)
deallocate (recvcounts_loc, displs_loc, _STAT)
deallocate (x, y, _STAT)
call MAPL_TimerOff(this%GENSTATE,"1_genABIgrid")


! __ s2. set distributed LS
!
call MAPL_TimerOn(this%GENSTATE,"2_ABIgrid_LS")

! -- root
locstream_factory = LocStreamFactory(lons,lats,_RC)
LS_rt = locstream_factory%create_locstream(_RC)

! -- proc
locstream_factory = LocStreamFactory(lons_chunk,lats_chunk,_RC)
LS_chunk = locstream_factory%create_locstream_on_proc(_RC)

! -- distributed with background grid
call ESMF_FieldBundleGet(this%bundle,grid=grid,_RC)
LS_ds = locstream_factory%create_locstream(grid=grid,_RC)
LS_ds = locstream_factory%create_locstream_on_proc(grid=grid,_RC)

fieldA = ESMF_FieldCreate (LS_rt, name='A', typekind=ESMF_TYPEKIND_R8, _RC)
fieldA = ESMF_FieldCreate (LS_chunk, name='A', typekind=ESMF_TYPEKIND_R8, _RC)
fieldB = ESMF_FieldCreate (LS_ds, name='B', typekind=ESMF_TYPEKIND_R8, _RC)

call ESMF_FieldGet( fieldA, localDE=0, farrayPtr=ptA)
call ESMF_FieldGet( fieldB, localDE=0, farrayPtr=ptB)
if (mypet == 0) then
ptA(:) = lons(:)
end if

ptA(:) = lons_chunk(:)
call ESMF_FieldRedistStore (fieldA, fieldB, RH, _RC)
call MPI_Barrier(mpic,ierr)
_VERIFY (ierr)
call ESMF_FieldRedist (fieldA, fieldB, RH, _RC)
lons_ds = ptB

if (mypet == 0) then
ptA(:) = lats(:)
end if
ptA(:) = lats_chunk(:)
call MPI_Barrier(mpic,ierr)
_VERIFY (ierr)
call ESMF_FieldRedist (fieldA, fieldB, RH, _RC)
lats_ds = ptB

call ESMF_FieldRedistRelease(RH, noGarbage=.true., _RC)
!! write(6,*) 'ip, size(lons_ds)=', mypet, size(lons_ds)

call ESMF_FieldDestroy(fieldA,nogarbage=.true.,_RC)
call ESMF_FieldDestroy(fieldB,nogarbage=.true.,_RC)
call ESMF_FieldRedistRelease(RH, noGarbage=.true., _RC)

call MAPL_TimerOff(this%GENSTATE,"2_ABIgrid_LS")


! __ s3. find n.n. CS pts for LS_ds (halo)
!
call MAPL_TimerOn(this%GENSTATE,"3_CS_halo")
obs_lons = lons_ds * MAPL_DEGREES_TO_RADIANS_R8
obs_lats = lats_ds * MAPL_DEGREES_TO_RADIANS_R8
nx = size ( lons_ds )
Expand Down Expand Up @@ -407,6 +478,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
end if
end do
end do
call MAPL_TimerOff(this%GENSTATE,"3_CS_halo")


! ----
Expand All @@ -415,6 +487,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
! - mpi_gatherV
!

call MAPL_TimerOn(this%GENSTATE,"4_gatherV")

! __ s4.1 find this%lons/lats on root for NC output
!
Expand Down Expand Up @@ -442,11 +515,11 @@ module subroutine create_Geosat_grid_find_mask(this, rc)

! __ s4.2 find this%recvcounts / this%displs
!
allocate( this%recvcounts(npes), this%displs(npes), _STAT )
allocate( recvcounts_loc(npes), displs_loc(npes), _STAT )
allocate( this%recvcounts(petcount), this%displs(petcount), _STAT )
allocate( recvcounts_loc(petcount), displs_loc(petcount), _STAT )
recvcounts_loc(:)=1
displs_loc(1)=0
do i=2, npes
do i=2, petcount
displs_loc(i) = displs_loc(i-1) + recvcounts_loc(i-1)
end do
call MPI_gatherv ( this%npt_mask, 1, MPI_INTEGER, &
Expand All @@ -456,7 +529,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
this%recvcounts(:) = 0
end if
this%displs(1)=0
do i=2, npes
do i=2, petcount
this%displs(i) = this%displs(i-1) + this%recvcounts(i-1)
end do

Expand All @@ -471,6 +544,8 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
this%lats, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )

call MAPL_TimerOff(this%GENSTATE,"4_gatherV")

_RETURN(_SUCCESS)
end subroutine create_Geosat_grid_find_mask

Expand Down Expand Up @@ -589,7 +664,7 @@ module subroutine regrid_append_file(this,current_time,rc)
integer :: i, j, k, rank
integer :: nx, nz
integer :: ix, iy, m
integer :: mypet, npes, nsend
integer :: mypet, petcount, nsend
integer :: iroot, ierr
integer :: mpic
integer, allocatable :: recvcounts_3d(:)
Expand All @@ -602,7 +677,7 @@ module subroutine regrid_append_file(this,current_time,rc)

! -- fixed for all fields
call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=npes, localpet=mypet, _RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=petcount, localpet=mypet, _RC)
iroot=0
nx = this%npt_mask
nz = this%vdata%lm
Expand All @@ -615,7 +690,7 @@ module subroutine regrid_append_file(this,current_time,rc)
allocate ( p_dst_2d_full (0), _STAT )
allocate ( p_dst_3d_full (0), _STAT )
end if
allocate( recvcounts_3d(npes), displs_3d(npes), _STAT )
allocate( recvcounts_3d(petcount), displs_3d(petcount), _STAT )
recvcounts_3d(:) = nz * this%recvcounts(:)
displs_3d(:) = nz * this%displs(:)

Expand Down