diff --git a/src/drivers/mct/cime_config/namelist_definition_drv.xml b/src/drivers/mct/cime_config/namelist_definition_drv.xml
index 600b56cbac7..ebc5596bf86 100644
--- a/src/drivers/mct/cime_config/namelist_definition_drv.xml
+++ b/src/drivers/mct/cime_config/namelist_definition_drv.xml
@@ -1470,6 +1470,19 @@
+
+ logical
+ reprosum
+ seq_infodata_inparm
+
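+ Allow INF and NaN values in the summands passed to the reproducible sum. If .false., the sum aborts when any summand is INF or NaN.
+ default: .false.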
+
+
+ .false.
+
+
+
real
reprosum
diff --git a/src/drivers/mct/main/cime_comp_mod.F90 b/src/drivers/mct/main/cime_comp_mod.F90
index 25dbe028360..de2439aa137 100644
--- a/src/drivers/mct/main/cime_comp_mod.F90
+++ b/src/drivers/mct/main/cime_comp_mod.F90
@@ -417,6 +417,7 @@ module cime_comp_mod
logical :: shr_map_dopole ! logical for dopole in shr_map_mod
logical :: domain_check ! .true. => check consistency of domains
logical :: reprosum_use_ddpdd ! setup reprosum, use ddpdd
+ logical :: reprosum_allow_infnan ! setup reprosum, allow INF and NaN in summands
real(r8) :: reprosum_diffmax ! setup reprosum, set rel_diff_max
logical :: reprosum_recompute ! setup reprosum, recompute if tolerance exceeded
@@ -935,6 +936,7 @@ subroutine cime_pre_init2()
wall_time_limit=wall_time_limit , &
force_stop_at=force_stop_at , &
reprosum_use_ddpdd=reprosum_use_ddpdd , &
+ reprosum_allow_infnan=reprosum_allow_infnan, &
reprosum_diffmax=reprosum_diffmax , &
reprosum_recompute=reprosum_recompute, &
max_cplstep_time=max_cplstep_time)
@@ -946,6 +948,7 @@ subroutine cime_pre_init2()
call shr_reprosum_setopts(&
repro_sum_use_ddpdd_in = reprosum_use_ddpdd, &
+ repro_sum_allow_infnan_in = reprosum_allow_infnan, &
repro_sum_rel_diff_max_in = reprosum_diffmax, &
repro_sum_recompute_in = reprosum_recompute)
diff --git a/src/drivers/mct/main/seq_diag_mct.F90 b/src/drivers/mct/main/seq_diag_mct.F90
index efe851462c1..b1f80d865cc 100644
--- a/src/drivers/mct/main/seq_diag_mct.F90
+++ b/src/drivers/mct/main/seq_diag_mct.F90
@@ -45,6 +45,7 @@ module seq_diag_mct
use component_type_mod, only : COMPONENT_GET_DOM_CX, COMPONENT_GET_C2X_CX, &
COMPONENT_GET_X2C_CX, COMPONENT_TYPE
use seq_infodata_mod, only : seq_infodata_type, seq_infodata_getdata
+ use shr_reprosum_mod, only: shr_reprosum_calc
implicit none
save
@@ -2237,7 +2238,9 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment)
integer(in) :: iam ! pe number
integer(in) :: km,ka ! field indices
integer(in) :: ns ! size of local AV
+ integer(in) :: rcode ! allocate return code
real(r8), pointer :: weight(:) ! weight
+ real(r8), allocatable :: weighted_data(:,:) ! weighted data
type(mct_string) :: mstring ! mct char type
character(CL) :: lcomment ! should be long enough
character(CL) :: itemc ! string converted to char
@@ -2250,11 +2253,8 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment)
! print instantaneous budget data
!-------------------------------------------------------------------------------
- call seq_comm_setptrs(ID,&
- mpicom=mpicom, iam=iam)
-
- call seq_infodata_GetData(infodata,&
- bfbflag=bfbflag)
+ call seq_comm_setptrs(ID, mpicom=mpicom, iam=iam)
+ call seq_infodata_GetData(infodata, bfbflag=bfbflag)
lcomment = ''
if (present(comment)) then
@@ -2267,82 +2267,44 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment)
km = mct_aVect_indexRA(dom%data,'mask')
ka = mct_aVect_indexRA(dom%data,afldname)
kflds = mct_aVect_nRattr(AV)
- allocate(sumbuf(kflds),sumbufg(kflds))
-
- sumbuf = 0.0_r8
-
- if (bfbflag) then
-
- npts = mct_aVect_lsize(AV)
- allocate(weight(npts))
- weight(:) = 1.0_r8
- do n = 1,npts
- if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then
- weight(n) = 0.0_r8
- else
- weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth
- endif
- enddo
+ allocate(sumbufg(kflds),stat=rcode)
+ if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate sumbufg')
- allocate(maxbuf(kflds),maxbufg(kflds))
- maxbuf = 0.0_r8
-
- do n = 1,npts
- do k = 1,kflds
- if (.not. shr_const_isspval(AV%rAttr(k,n))) then
- maxbuf(k) = max(maxbuf(k),abs(AV%rAttr(k,n)*weight(n)))
- endif
- enddo
- enddo
-
- call shr_mpi_max(maxbuf,maxbufg,mpicom,subname,all=.true.)
- call shr_mpi_sum(npts,nptsg,mpicom,subname,all=.true.)
+ npts = mct_aVect_lsize(AV)
+ allocate(weight(npts),stat=rcode)
+ if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate weight')
- do k = 1,kflds
- if (maxbufg(k) < 1000.0*TINY(maxbufg(k)) .or. &
- maxbufg(k) > HUGE(maxbufg(k))/(2.0_r8*nptsg)) then
- maxbufg(k) = 0.0_r8
- else
- maxbufg(k) = (1.1_r8) * maxbufg(k) * nptsg
- endif
- enddo
+ weight(:) = 1.0_r8
+ do n = 1,npts
+ if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then
+ weight(n) = 0.0_r8
+ else
+ weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth
+ endif
+ enddo
- allocate(isumbuf(kflds),isumbufg(kflds))
- isumbuf = 0
- ihuge = HUGE(isumbuf)
+ if (bfbflag) then
+ allocate(weighted_data(npts,kflds),stat=rcode)
+ if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate weighted_data')
+ weighted_data = 0.0_r8
do n = 1,npts
do k = 1,kflds
if (.not. shr_const_isspval(AV%rAttr(k,n))) then
- if (abs(maxbufg(k)) > 1000.0_r8 * TINY(maxbufg)) then
- isumbuf(k) = isumbuf(k) + int((AV%rAttr(k,n)*weight(n)/maxbufg(k))*ihuge,i8)
- endif
+ weighted_data(n,k) = AV%rAttr(k,n)*weight(n)
endif
enddo
enddo
- call shr_mpi_sum(isumbuf,isumbufg,mpicom,subname)
+ call shr_reprosum_calc (weighted_data, sumbufg, npts, npts, kflds, &
+ commid=mpicom)
- do k = 1,kflds
- sumbufg(k) = isumbufg(k)*maxbufg(k)/ihuge
- enddo
-
- deallocate(weight)
- deallocate(maxbuf,maxbufg)
- deallocate(isumbuf,isumbufg)
+ deallocate(weighted_data)
else
-
- npts = mct_aVect_lsize(AV)
- allocate(weight(npts))
- weight(:) = 1.0_r8
- do n = 1,npts
- if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then
- weight(n) = 0.0_r8
- else
- weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth
- endif
- enddo
+ allocate(sumbuf(kflds),stat=rcode)
+ if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate sumbuf')
+ sumbuf = 0.0_r8
do n = 1,npts
do k = 1,kflds
@@ -2355,9 +2317,10 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment)
!--- global reduction ---
call shr_mpi_sum(sumbuf,sumbufg,mpicom,subname)
- deallocate(weight)
+ deallocate(sumbuf)
endif
+ deallocate(weight)
if (iam == 0) then
! write(logunit,*) 'sdAV: *** writing ',trim(lcomment),': k fld min/max/sum ***'
@@ -2374,7 +2337,7 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment)
call shr_sys_flush(logunit)
endif
- deallocate(sumbuf,sumbufg)
+ deallocate(sumbufg)
100 format('comm_diag ',a3,1x,a4,1x,i3,es26.19,1x,a,1x,a)
101 format('comm_diag ',a3,1x,a4,1x,i3,es26.19,1x,a)
diff --git a/src/drivers/mct/shr/seq_infodata_mod.F90 b/src/drivers/mct/shr/seq_infodata_mod.F90
index 82a984c77dd..47ec473c25b 100644
--- a/src/drivers/mct/shr/seq_infodata_mod.F90
+++ b/src/drivers/mct/shr/seq_infodata_mod.F90
@@ -176,6 +176,7 @@ MODULE seq_infodata_mod
logical :: mct_usevector ! flag for mct vector
logical :: reprosum_use_ddpdd ! use ddpdd algorithm
+ logical :: reprosum_allow_infnan ! allow INF and NaN summands
real(SHR_KIND_R8) :: reprosum_diffmax ! maximum difference tolerance
logical :: reprosum_recompute ! recompute reprosum with nonscalable algorithm
! if reprosum_diffmax is exceeded
@@ -412,6 +413,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag)
real(SHR_KIND_R8) :: eps_ogrid ! ocn grid error tolerance
real(SHR_KIND_R8) :: eps_oarea ! ocn area error tolerance
logical :: reprosum_use_ddpdd ! use ddpdd algorithm
+ logical :: reprosum_allow_infnan ! allow INF and NaN summands
real(SHR_KIND_R8) :: reprosum_diffmax ! maximum difference tolerance
logical :: reprosum_recompute ! recompute reprosum with nonscalable algorithm
! if reprosum_diffmax is exceeded
@@ -452,7 +454,8 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag)
eps_frac, eps_amask, &
eps_agrid, eps_aarea, eps_omask, eps_ogrid, &
eps_oarea, esmf_map_flag, &
- reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, &
+ reprosum_use_ddpdd, reprosum_allow_infnan, &
+ reprosum_diffmax, reprosum_recompute, &
mct_usealltoall, mct_usevector, max_cplstep_time, model_doi_url
!-------------------------------------------------------------------------------
@@ -560,6 +563,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag)
eps_ogrid = 1.0e-02_SHR_KIND_R8
eps_oarea = 1.0e-01_SHR_KIND_R8
reprosum_use_ddpdd = .false.
+ reprosum_allow_infnan = .false.
reprosum_diffmax = -1.0e-8
reprosum_recompute = .false.
mct_usealltoall = .false.
@@ -685,6 +689,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag)
infodata%eps_ogrid = eps_ogrid
infodata%eps_oarea = eps_oarea
infodata%reprosum_use_ddpdd = reprosum_use_ddpdd
+ infodata%reprosum_allow_infnan = reprosum_allow_infnan
infodata%reprosum_diffmax = reprosum_diffmax
infodata%reprosum_recompute = reprosum_recompute
infodata%mct_usealltoall = mct_usealltoall
@@ -977,7 +982,8 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_
lnd_nx, lnd_ny, rof_nx, rof_ny, ice_nx, ice_ny, ocn_nx, ocn_ny, &
glc_nx, glc_ny, eps_frac, eps_amask, &
eps_agrid, eps_aarea, eps_omask, eps_ogrid, eps_oarea, &
- reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, &
+ reprosum_use_ddpdd, reprosum_allow_infnan, &
+ reprosum_diffmax, reprosum_recompute, &
atm_resume, lnd_resume, ocn_resume, ice_resume, &
glc_resume, rof_resume, wav_resume, cpl_resume, &
mct_usealltoall, mct_usevector, max_cplstep_time, model_doi_url, &
@@ -1085,6 +1091,7 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_
real(SHR_KIND_R8), optional, intent(OUT) :: eps_ogrid ! ocn grid error tolerance
real(SHR_KIND_R8), optional, intent(OUT) :: eps_oarea ! ocn area error tolerance
logical, optional, intent(OUT) :: reprosum_use_ddpdd ! use ddpdd algorithm
+ logical, optional, intent(OUT) :: reprosum_allow_infnan ! allow INF and NaN summands
real(SHR_KIND_R8), optional, intent(OUT) :: reprosum_diffmax ! maximum difference tolerance
logical, optional, intent(OUT) :: reprosum_recompute ! recompute if tolerance exceeded
logical, optional, intent(OUT) :: mct_usealltoall ! flag for mct alltoall
@@ -1261,6 +1268,7 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_
if ( present(eps_ogrid) ) eps_ogrid = infodata%eps_ogrid
if ( present(eps_oarea) ) eps_oarea = infodata%eps_oarea
if ( present(reprosum_use_ddpdd)) reprosum_use_ddpdd = infodata%reprosum_use_ddpdd
+ if ( present(reprosum_allow_infnan)) reprosum_allow_infnan = infodata%reprosum_allow_infnan
if ( present(reprosum_diffmax) ) reprosum_diffmax = infodata%reprosum_diffmax
if ( present(reprosum_recompute)) reprosum_recompute = infodata%reprosum_recompute
if ( present(mct_usealltoall)) mct_usealltoall = infodata%mct_usealltoall
@@ -1555,7 +1563,8 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_
lnd_nx, lnd_ny, rof_nx, rof_ny, ice_nx, ice_ny, ocn_nx, ocn_ny, &
glc_nx, glc_ny, eps_frac, eps_amask, &
eps_agrid, eps_aarea, eps_omask, eps_ogrid, eps_oarea, &
- reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, &
+ reprosum_use_ddpdd, reprosum_allow_infnan, &
+ reprosum_diffmax, reprosum_recompute, &
atm_resume, lnd_resume, ocn_resume, ice_resume, &
glc_resume, rof_resume, wav_resume, cpl_resume, &
mct_usealltoall, mct_usevector, glc_valid_input)
@@ -1661,6 +1670,7 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_
real(SHR_KIND_R8), optional, intent(IN) :: eps_ogrid ! ocn grid error tolerance
real(SHR_KIND_R8), optional, intent(IN) :: eps_oarea ! ocn area error tolerance
logical, optional, intent(IN) :: reprosum_use_ddpdd ! use ddpdd algorithm
+ logical, optional, intent(IN) :: reprosum_allow_infnan ! allow INF and NaN summands
real(SHR_KIND_R8), optional, intent(IN) :: reprosum_diffmax ! maximum difference tolerance
logical, optional, intent(IN) :: reprosum_recompute ! recompute if tolerance exceeded
logical, optional, intent(IN) :: mct_usealltoall ! flag for mct alltoall
@@ -1835,6 +1845,7 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_
if ( present(eps_ogrid) ) infodata%eps_ogrid = eps_ogrid
if ( present(eps_oarea) ) infodata%eps_oarea = eps_oarea
if ( present(reprosum_use_ddpdd)) infodata%reprosum_use_ddpdd = reprosum_use_ddpdd
+ if ( present(reprosum_allow_infnan)) infodata%reprosum_allow_infnan = reprosum_allow_infnan
if ( present(reprosum_diffmax) ) infodata%reprosum_diffmax = reprosum_diffmax
if ( present(reprosum_recompute)) infodata%reprosum_recompute = reprosum_recompute
if ( present(mct_usealltoall)) infodata%mct_usealltoall = mct_usealltoall
@@ -2257,6 +2268,7 @@ subroutine seq_infodata_bcast(infodata,mpicom)
call shr_mpi_bcast(infodata%eps_ogrid, mpicom)
call shr_mpi_bcast(infodata%eps_oarea, mpicom)
call shr_mpi_bcast(infodata%reprosum_use_ddpdd, mpicom)
+ call shr_mpi_bcast(infodata%reprosum_allow_infnan, mpicom)
call shr_mpi_bcast(infodata%reprosum_diffmax, mpicom)
call shr_mpi_bcast(infodata%reprosum_recompute, mpicom)
call shr_mpi_bcast(infodata%mct_usealltoall, mpicom)
@@ -2931,6 +2943,7 @@ SUBROUTINE seq_infodata_print( infodata )
write(logunit,F0R) subname,'eps_oarea = ', infodata%eps_oarea
write(logunit,F0L) subname,'reprosum_use_ddpdd = ', infodata%reprosum_use_ddpdd
+ write(logunit,F0L) subname,'reprosum_allow_infnan = ', infodata%reprosum_allow_infnan
write(logunit,F0R) subname,'reprosum_diffmax = ', infodata%reprosum_diffmax
write(logunit,F0L) subname,'reprosum_recompute = ', infodata%reprosum_recompute
diff --git a/src/share/util/shr_reprosum_mod.F90 b/src/share/util/shr_reprosum_mod.F90
index 9acfa54813c..a8ef29c1b15 100644
--- a/src/share/util/shr_reprosum_mod.F90
+++ b/src/share/util/shr_reprosum_mod.F90
@@ -38,7 +38,11 @@ module shr_reprosum_mod
use shr_log_mod, only: s_loglev => shr_log_Level
use shr_log_mod, only: s_logunit => shr_log_Unit
use shr_sys_mod, only: shr_sys_abort
- use shr_infnan_mod,only: shr_infnan_isnan, shr_infnan_isinf
+ use shr_infnan_mod,only: shr_infnan_inf_type, assignment(=), &
+ shr_infnan_posinf, shr_infnan_neginf, &
+ shr_infnan_nan, &
+ shr_infnan_isnan, shr_infnan_isinf, &
+ shr_infnan_isposinf, shr_infnan_isneginf
use perf_mod
!-----------------------------------------------------------------------
@@ -86,12 +90,15 @@ module shr_reprosum_mod
!----------------------------------------------------------------------------
logical :: repro_sum_use_ddpdd = .false.
+ logical :: repro_sum_allow_infnan = .false.
+
CONTAINS
!
!========================================================================
!
subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, &
+ repro_sum_allow_infnan_in, &
repro_sum_rel_diff_max_in, &
repro_sum_recompute_in, &
repro_sum_master, &
@@ -104,6 +111,8 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, &
!------------------------------Arguments--------------------------------
! Use DDPDD algorithm instead of fixed precision algorithm
logical, intent(in), optional :: repro_sum_use_ddpdd_in
+ ! Allow INF or NaN in summands
+ logical, intent(in), optional :: repro_sum_allow_infnan_in
! maximum permissible difference between reproducible and
! nonreproducible sums
real(r8), intent(in), optional :: repro_sum_rel_diff_max_in
@@ -142,6 +151,9 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, &
if ( present(repro_sum_use_ddpdd_in) ) then
repro_sum_use_ddpdd = repro_sum_use_ddpdd_in
endif
+ if ( present(repro_sum_allow_infnan_in) ) then
+ repro_sum_allow_infnan = repro_sum_allow_infnan_in
+ endif
if ( present(repro_sum_rel_diff_max_in) ) then
shr_reprosum_reldiffmax = repro_sum_rel_diff_max_in
endif
@@ -159,6 +171,14 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, &
'distributed sum algorithm'
endif
+ if ( repro_sum_allow_infnan ) then
+ write(logunit,*) 'SHR_REPROSUM_SETOPTS: ',&
+ 'Will calculate sum when INF or NaN are included in summands'
+ else
+ write(logunit,*) 'SHR_REPROSUM_SETOPTS: ',&
+ 'Will abort if INF or NaN are included in summands'
+ endif
+
if (shr_reprosum_reldiffmax >= 0._r8) then
write(logunit,*) ' ',&
'with a maximum relative error tolerance of ', &
@@ -185,7 +205,7 @@ end subroutine shr_reprosum_setopts
!
subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
- nflds, ddpdd_sum, &
+ nflds, allow_infnan, ddpdd_sum, &
arr_gbl_max, arr_gbl_max_out, &
arr_max_levels, arr_max_levels_out, &
gbl_max_nsummands, gbl_max_nsummands_out,&
@@ -280,6 +300,10 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
! use ddpdd algorithm instead
! of fixed precision algorithm
+ logical, intent(in), optional :: allow_infnan
+ ! if .true., allow INF or NaN input values; if .false., abort.
+ ! if absent, the module-level repro_sum_allow_infnan setting is used.
+
real(r8), intent(in), optional :: arr_gbl_max(nflds)
! upper bound on max(abs(arr))
@@ -312,13 +336,14 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
! flag enabling/disabling testing that gmax and max_levels are
! accurate/sufficient. Default is enabled.
- integer, intent(inout), optional :: repro_sum_stats(5)
+ integer, intent(inout), optional :: repro_sum_stats(6)
! increment running totals for
! (1) one-reduction repro_sum
! (2) two-reduction repro_sum
! (3) both types in one call
! (4) nonrepro_sum
! (5) global max nsummands reduction
+ ! (6) global logical-OR reduction of the 3*nflds INF/NaN flags
real(r8), intent(out), optional :: rel_diff(2,nflds)
! relative and absolute
@@ -331,6 +356,8 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
!
! Local workspace
!
+ logical :: abort_inf_nan ! flag indicating whether to
+ ! abort if INF or NaN found in input
logical :: use_ddpdd_sum ! flag indicating whether to
! use shr_reprosum_ddpdd or not
logical :: recompute ! flag indicating need to
@@ -341,8 +368,23 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
! are accurate/sufficient
logical :: nan_check, inf_check ! flag on whether there are
! NaNs and INFs in input array
+ logical :: inf_nan_lchecks(3,nflds)! flags on whether there are
+ ! NaNs, positive INFs, or negative INFs
+ ! for each input field locally
+ logical :: inf_nan_gchecks(3,nflds)! flags on whether there are
+ ! NaNs, positive INFs, or negative INFs
+ ! for each input field
+ logical :: arr_gsum_infnan(nflds) ! flag on whether field sum is a
+ ! NaN or INF
+
+ integer :: gbl_lor_red ! global lor reduction? (0/1)
+ integer :: gbl_max_red ! global max reduction? (0/1)
+ integer :: repro_sum_fast ! 1 reduction repro_sum? (0/1)
+ integer :: repro_sum_slow ! 2 reduction repro_sum? (0/1)
+ integer :: repro_sum_both ! both fast and slow? (0/1)
+ integer :: nonrepro_sum ! nonrepro_sum? (0/1)
- integer :: num_nans, num_infs ! count of NaNs and INFs in
+ integer :: nan_count, inf_count ! local count of NaNs and INFs in
! input array
integer :: omp_nthreads ! number of OpenMP threads
integer :: mpi_comm ! MPI subcommunicator
@@ -375,11 +417,6 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
integer :: max_levels(nflds) ! maximum number of levels of
! integer expansion to use
integer :: max_level ! maximum value in max_levels
- integer :: gbl_max_red ! global max local sum reduction? (0/1)
- integer :: repro_sum_fast ! 1 reduction repro_sum? (0/1)
- integer :: repro_sum_slow ! 2 reduction repro_sum? (0/1)
- integer :: repro_sum_both ! both fast and slow? (0/1)
- integer :: nonrepro_sum ! nonrepro_sum? (0/1)
real(r8) :: xmax_nsummands ! dble of max_nsummands
real(r8) :: arr_lsum(nflds) ! local sums
@@ -396,38 +433,81 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
!
!-----------------------------------------------------------------------
!
-! check whether input contains NaNs or INFs, and abort if so
+! initialize local statistics variables
+ gbl_lor_red = 0
+ gbl_max_red = 0
+ repro_sum_fast = 0
+ repro_sum_slow = 0
+ repro_sum_both = 0
+ nonrepro_sum = 0
- call t_startf('shr_reprosum_NaN_INF_Chk')
- nan_check = .false.
- inf_check = .false.
- num_nans = 0
- num_infs = 0
+! set MPI communicator
+ if ( present(commid) ) then
+ mpi_comm = commid
+ else
+ mpi_comm = MPI_COMM_WORLD
+ endif
+ call t_barrierf('sync_repro_sum',mpi_comm)
- nan_check = any(shr_infnan_isnan(arr))
- inf_check = any(shr_infnan_isinf(arr))
- if (nan_check .or. inf_check) then
- do ifld=1,nflds
- do isum=1,nsummands
- if (shr_infnan_isnan(arr(isum,ifld))) then
- num_nans = num_nans + 1
- endif
- if (shr_infnan_isinf(arr(isum,ifld))) then
- num_infs = num_infs + 1
- endif
- end do
- end do
+! decide whether to abort on NaN/INF input: allow_infnan, when present, overrides the module-level repro_sum_allow_infnan setting
+ abort_inf_nan = .not. repro_sum_allow_infnan
+ if ( present(allow_infnan) ) then
+ abort_inf_nan = .not. allow_infnan
+ endif
- call t_stopf('shr_reprosum_NaN_INF_Chk')
- if ((num_nans > 0) .or. (num_infs > 0)) then
- call mpi_comm_rank(MPI_COMM_WORLD, mypid, ierr)
- write(s_logunit,37) real(num_nans,r8), real(num_infs,r8), mypid
+ call t_startf('shr_reprosum_INF_NaN_Chk')
+
+! initialize flags to indicate that no NaNs or INFs are present in the input data
+ inf_nan_gchecks = .false.
+ arr_gsum_infnan = .false.
+
+ if (abort_inf_nan) then
+
+! check whether input contains NaNs or INFs, and abort if so
+ nan_check = any(shr_infnan_isnan(arr))
+ inf_check = any(shr_infnan_isinf(arr))
+
+ if (nan_check .or. inf_check) then
+
+ nan_count = count(shr_infnan_isnan(arr))
+ inf_count = count(shr_infnan_isinf(arr))
+
+ if ((nan_count > 0) .or. (inf_count > 0)) then
+ call mpi_comm_rank(MPI_COMM_WORLD, mypid, ierr)
+ write(s_logunit,37) real(nan_count,r8), real(inf_count,r8), mypid
37 format("SHR_REPROSUM_CALC: Input contains ",e12.5, &
" NaNs and ", e12.5, " INFs on process ", i7)
- call shr_sys_abort("shr_reprosum_calc ERROR: NaNs or INFs in input")
+ call shr_sys_abort("shr_reprosum_calc ERROR: NaNs or INFs in input")
+ endif
+
+ endif
+
+ else
+
+! determine whether any fields contain NaNs or INFs, and avoid processing them
+! via integer expansions
+ inf_nan_lchecks = .false.
+
+ do ifld=1,nflds
+ inf_nan_lchecks(1,ifld) = any(shr_infnan_isnan(arr(:,ifld)))
+ inf_nan_lchecks(2,ifld) = any(shr_infnan_isposinf(arr(:,ifld)))
+ inf_nan_lchecks(3,ifld) = any(shr_infnan_isneginf(arr(:,ifld)))
+ end do
+
+ call t_startf("repro_sum_allr_lor")
+ call mpi_allreduce (inf_nan_lchecks, inf_nan_gchecks, 3*nflds, &
+ MPI_LOGICAL, MPI_LOR, mpi_comm, ierr)
+ gbl_lor_red = 1
+ call t_stopf("repro_sum_allr_lor")
+
+ do ifld=1,nflds
+ arr_gsum_infnan(ifld) = any(inf_nan_gchecks(:,ifld))
+ enddo
+
endif
+ call t_stopf('shr_reprosum_INF_NaN_Chk')
+
! check whether should use shr_reprosum_ddpdd algorithm
use_ddpdd_sum = repro_sum_use_ddpdd
if ( present(ddpdd_sum) ) then
@@ -439,21 +519,6 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
! If not, always use ddpdd.
use_ddpdd_sum = use_ddpdd_sum .or. (radix(0._r8) /= radix(0_i8))
-! initialize local statistics variables
- gbl_max_red = 0
- repro_sum_fast = 0
- repro_sum_slow = 0
- repro_sum_both = 0
- nonrepro_sum = 0
-
-! set MPI communicator
- if ( present(commid) ) then
- mpi_comm = commid
- else
- mpi_comm = MPI_COMM_WORLD
- endif
- call t_barrierf('sync_repro_sum',mpi_comm)
-
if ( use_ddpdd_sum ) then
call t_startf('shr_reprosum_ddpdd')
@@ -548,8 +613,8 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
endif
call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, &
nflds, arr_max_shift, arr_gmax_exp, &
- arr_max_levels, max_level, validate, &
- recompute, omp_nthreads, mpi_comm)
+ arr_max_levels, max_level, arr_gsum_infnan, &
+ validate, recompute, omp_nthreads, mpi_comm)
! record statistics, etc.
repro_sum_fast = 1
@@ -598,13 +663,15 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
do ifld=1,nflds
arr_exp_tlmin = MAXEXPONENT(1._r8)
arr_exp_tlmax = MINEXPONENT(1._r8)
- do isum=isum_beg(ithread),isum_end(ithread)
- if (arr(isum,ifld) .ne. 0.0_r8) then
- arr_exp = exponent(arr(isum,ifld))
- arr_exp_tlmin = min(arr_exp,arr_exp_tlmin)
- arr_exp_tlmax = max(arr_exp,arr_exp_tlmax)
- endif
- end do
+ if (.not. arr_gsum_infnan(ifld)) then
+ do isum=isum_beg(ithread),isum_end(ithread)
+ if (arr(isum,ifld) .ne. 0.0_r8) then
+ arr_exp = exponent(arr(isum,ifld))
+ arr_exp_tlmin = min(arr_exp,arr_exp_tlmin)
+ arr_exp_tlmax = max(arr_exp,arr_exp_tlmax)
+ endif
+ end do
+ endif
arr_tlmin_exp(ifld,ithread) = arr_exp_tlmin
arr_tlmax_exp(ifld,ithread) = arr_exp_tlmax
end do
@@ -628,9 +695,9 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
arr_gmax_exp(:) = -arr_gextremes(1:nflds,1)
arr_gmin_exp(:) = arr_gextremes(1:nflds,2)
-! if a field is identically zero, arr_gmin_exp still equals MAXEXPONENT
-! and arr_gmax_exp still equals MINEXPONENT. In this case, set
-! arr_gmin_exp = arr_gmax_exp = MINEXPONENT
+! if a field is identically zero or contains INFs or NaNs, arr_gmin_exp
+! still equals MAXEXPONENT and arr_gmax_exp still equals MINEXPONENT.
+! In this case, set arr_gmin_exp = arr_gmax_exp = MINEXPONENT
do ifld=1,nflds
arr_gmin_exp(ifld) = min(arr_gmax_exp(ifld),arr_gmin_exp(ifld))
enddo
@@ -695,10 +762,10 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
! calculate sum
validate = .false.
- call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, nflds, &
- arr_max_shift, arr_gmax_exp, max_levels, &
- max_level, validate, recompute, &
- omp_nthreads, mpi_comm)
+ call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, &
+ nflds, arr_max_shift, arr_gmax_exp, &
+ max_levels, max_level, arr_gsum_infnan, &
+ validate, recompute, omp_nthreads, mpi_comm)
endif
@@ -720,13 +787,17 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
!$omp default(shared) &
!$omp private(ifld, isum)
do ifld=1,nflds
- do isum=1,nsummands
- arr_lsum(ifld) = arr(isum,ifld) + arr_lsum(ifld)
- end do
+ if (.not. arr_gsum_infnan(ifld)) then
+ do isum=1,nsummands
+ arr_lsum(ifld) = arr(isum,ifld) + arr_lsum(ifld)
+ end do
+ endif
end do
+ call t_startf("nonrepro_sum_allr_r8")
call mpi_allreduce (arr_lsum, arr_gsum_fast, nflds, &
MPI_REAL8, MPI_SUM, mpi_comm, ierr)
+ call t_stopf("nonrepro_sum_allr_r8")
call t_stopf('nonrepro_sum')
@@ -748,6 +819,25 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
endif
endif
+! Set field sums to NaN and INF, as needed
+ do ifld=1,nflds
+ if (arr_gsum_infnan(ifld)) then
+ if (inf_nan_gchecks(1,ifld)) then
+ ! NaN => NaN
+ arr_gsum(ifld) = shr_infnan_nan
+ else if (inf_nan_gchecks(2,ifld) .and. inf_nan_gchecks(3,ifld)) then
+ ! posINF and negINF => NaN
+ arr_gsum(ifld) = shr_infnan_nan
+ else if (inf_nan_gchecks(2,ifld)) then
+ ! posINF only => posINF
+ arr_gsum(ifld) = shr_infnan_posinf
+ else if (inf_nan_gchecks(3,ifld)) then
+ ! negINF only => negINF
+ arr_gsum(ifld) = shr_infnan_neginf
+ endif
+ endif
+ end do
+
! return statistics
if ( present(repro_sum_stats) ) then
repro_sum_stats(1) = repro_sum_stats(1) + repro_sum_fast
@@ -755,6 +845,7 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, &
repro_sum_stats(3) = repro_sum_stats(3) + repro_sum_both
repro_sum_stats(4) = repro_sum_stats(4) + nonrepro_sum
repro_sum_stats(5) = repro_sum_stats(5) + gbl_max_red
+ repro_sum_stats(6) = repro_sum_stats(6) + gbl_lor_red
endif
@@ -766,7 +857,7 @@ end subroutine shr_reprosum_calc
subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, &
arr_max_shift, arr_gmax_exp, max_levels, &
- max_level, validate, recompute, &
+ max_level, skip_field, validate, recompute, &
omp_nthreads, mpi_comm )
!----------------------------------------------------------------------
!
@@ -798,9 +889,14 @@ subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, &
integer, intent(in) :: mpi_comm ! MPI subcommunicator
real(r8), intent(in) :: arr(dsummands,nflds)
- ! input array
+ ! input array
- logical, intent(in):: validate
+ logical, intent(in) :: skip_field(nflds)
+ ! flag indicating whether the sum for this field should be
+ ! computed or not (used to skip over fields containing
+ ! NaN or INF summands)
+
+ logical, intent(in) :: validate
! flag indicating that accuracy of solution generated from
! arr_gmax_exp and max_levels should be tested
@@ -920,8 +1016,10 @@ subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, &
max_error(ifld,ithread) = 0
not_exact(ifld,ithread) = 0
-
i8_arr_tlsum_level(:,ifld,ithread) = 0_i8
+
+ if (skip_field(ifld)) cycle
+
do isum=isum_beg(ithread),isum_end(ithread)
arr_remainder = 0.0_r8
@@ -1370,8 +1468,11 @@ subroutine shr_reprosum_ddpdd (arr, arr_gsum, nsummands, dsummands, &
enddo
+ call t_startf("repro_sum_allr_c16")
call mpi_allreduce (arr_lsum_dd, arr_gsum_dd, nflds, &
MPI_COMPLEX16, mpi_sumdd, mpi_comm, ierr)
+ call t_stopf("repro_sum_allr_c16")
+
do ifld=1,nflds
arr_gsum(ifld) = real(arr_gsum_dd(ifld))
enddo