diff --git a/src/drivers/mct/cime_config/namelist_definition_drv.xml b/src/drivers/mct/cime_config/namelist_definition_drv.xml index 600b56cbac7..ebc5596bf86 100644 --- a/src/drivers/mct/cime_config/namelist_definition_drv.xml +++ b/src/drivers/mct/cime_config/namelist_definition_drv.xml @@ -1470,6 +1470,19 @@ + + logical + reprosum + seq_infodata_inparm + + Allow INF and NaN in summands + default: .false. + + + .false. + + + real reprosum diff --git a/src/drivers/mct/main/cime_comp_mod.F90 b/src/drivers/mct/main/cime_comp_mod.F90 index 25dbe028360..de2439aa137 100644 --- a/src/drivers/mct/main/cime_comp_mod.F90 +++ b/src/drivers/mct/main/cime_comp_mod.F90 @@ -417,6 +417,7 @@ module cime_comp_mod logical :: shr_map_dopole ! logical for dopole in shr_map_mod logical :: domain_check ! .true. => check consistency of domains logical :: reprosum_use_ddpdd ! setup reprosum, use ddpdd + logical :: reprosum_allow_infnan ! setup reprosum, allow INF and NaN in summands real(r8) :: reprosum_diffmax ! setup reprosum, set rel_diff_max logical :: reprosum_recompute ! setup reprosum, recompute if tolerance exceeded @@ -935,6 +936,7 @@ subroutine cime_pre_init2() wall_time_limit=wall_time_limit , & force_stop_at=force_stop_at , & reprosum_use_ddpdd=reprosum_use_ddpdd , & + reprosum_allow_infnan=reprosum_allow_infnan, & reprosum_diffmax=reprosum_diffmax , & reprosum_recompute=reprosum_recompute, & max_cplstep_time=max_cplstep_time) @@ -946,6 +948,7 @@ subroutine cime_pre_init2() call shr_reprosum_setopts(& repro_sum_use_ddpdd_in = reprosum_use_ddpdd, & + repro_sum_allow_infnan_in = reprosum_allow_infnan, & repro_sum_rel_diff_max_in = reprosum_diffmax, & repro_sum_recompute_in = reprosum_recompute) diff --git a/src/drivers/mct/main/seq_diag_mct.F90 b/src/drivers/mct/main/seq_diag_mct.F90 index efe851462c1..b1f80d865cc 100644 --- a/src/drivers/mct/main/seq_diag_mct.F90 +++ b/src/drivers/mct/main/seq_diag_mct.F90 @@ -45,6 +45,7 @@ module seq_diag_mct use component_type_mod, only : COMPONENT_GET_DOM_CX, COMPONENT_GET_C2X_CX, & COMPONENT_GET_X2C_CX, COMPONENT_TYPE use seq_infodata_mod, only : seq_infodata_type, seq_infodata_getdata + use shr_reprosum_mod, only: shr_reprosum_calc implicit none save @@ -2237,7 +2238,9 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment) integer(in) :: iam ! pe number integer(in) :: km,ka ! field indices integer(in) :: ns ! size of local AV + integer(in) :: rcode ! allocate return code real(r8), pointer :: weight(:) ! weight + real(r8), allocatable :: weighted_data(:,:) ! weighted data type(mct_string) :: mstring ! mct char type character(CL) :: lcomment ! should be long enough character(CL) :: itemc ! string converted to char @@ -2250,11 +2253,8 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment) ! print instantaneous budget data !------------------------------------------------------------------------------- - call seq_comm_setptrs(ID,& - mpicom=mpicom, iam=iam) - - call seq_infodata_GetData(infodata,& - bfbflag=bfbflag) + call seq_comm_setptrs(ID, mpicom=mpicom, iam=iam) + call seq_infodata_GetData(infodata, bfbflag=bfbflag) lcomment = '' if (present(comment)) then @@ -2267,82 +2267,44 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment) km = mct_aVect_indexRA(dom%data,'mask') ka = mct_aVect_indexRA(dom%data,afldname) kflds = mct_aVect_nRattr(AV) - allocate(sumbuf(kflds),sumbufg(kflds)) - - sumbuf = 0.0_r8 - - if (bfbflag) then - - npts = mct_aVect_lsize(AV) - allocate(weight(npts)) - weight(:) = 1.0_r8 - do n = 1,npts - if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then - weight(n) = 0.0_r8 - else - weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth - endif - enddo + allocate(sumbufg(kflds),stat=rcode) + if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate sumbufg') - allocate(maxbuf(kflds),maxbufg(kflds)) - maxbuf = 0.0_r8 - - do n = 1,npts - do k = 1,kflds - if (.not. shr_const_isspval(AV%rAttr(k,n))) then - maxbuf(k) = max(maxbuf(k),abs(AV%rAttr(k,n)*weight(n))) - endif - enddo - enddo - - call shr_mpi_max(maxbuf,maxbufg,mpicom,subname,all=.true.) - call shr_mpi_sum(npts,nptsg,mpicom,subname,all=.true.) + npts = mct_aVect_lsize(AV) + allocate(weight(npts),stat=rcode) + if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate weight') - do k = 1,kflds - if (maxbufg(k) < 1000.0*TINY(maxbufg(k)) .or. & - maxbufg(k) > HUGE(maxbufg(k))/(2.0_r8*nptsg)) then - maxbufg(k) = 0.0_r8 - else - maxbufg(k) = (1.1_r8) * maxbufg(k) * nptsg - endif - enddo + weight(:) = 1.0_r8 + do n = 1,npts + if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then + weight(n) = 0.0_r8 + else + weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth + endif + enddo - allocate(isumbuf(kflds),isumbufg(kflds)) - isumbuf = 0 - ihuge = HUGE(isumbuf) + if (bfbflag) then + allocate(weighted_data(npts,kflds),stat=rcode) + if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate weighted_data') + weighted_data = 0.0_r8 do n = 1,npts do k = 1,kflds if (.not. shr_const_isspval(AV%rAttr(k,n))) then - if (abs(maxbufg(k)) > 1000.0_r8 * TINY(maxbufg)) then - isumbuf(k) = isumbuf(k) + int((AV%rAttr(k,n)*weight(n)/maxbufg(k))*ihuge,i8) - endif + weighted_data(n,k) = AV%rAttr(k,n)*weight(n) endif enddo enddo - call shr_mpi_sum(isumbuf,isumbufg,mpicom,subname) + call shr_reprosum_calc (weighted_data, sumbufg, npts, npts, kflds, & + commid=mpicom) - do k = 1,kflds - sumbufg(k) = isumbufg(k)*maxbufg(k)/ihuge - enddo - - deallocate(weight) - deallocate(maxbuf,maxbufg) - deallocate(isumbuf,isumbufg) + deallocate(weighted_data) else - - npts = mct_aVect_lsize(AV) - allocate(weight(npts)) - weight(:) = 1.0_r8 - do n = 1,npts - if (dom%data%rAttr(km,n) <= 1.0e-06_R8) then - weight(n) = 0.0_r8 - else - weight(n) = dom%data%rAttr(ka,n)*shr_const_rearth*shr_const_rearth - endif - enddo + allocate(sumbuf(kflds),stat=rcode) + if (rcode /= 0) call shr_sys_abort(trim(subname)//' allocate sumbuf') + sumbuf = 0.0_r8 do n = 1,npts do k = 1,kflds @@ -2355,9 +2317,10 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment) !--- global reduction --- call shr_mpi_sum(sumbuf,sumbufg,mpicom,subname) - deallocate(weight) + deallocate(sumbuf) endif + deallocate(weight) if (iam == 0) then ! write(logunit,*) 'sdAV: *** writing ',trim(lcomment),': k fld min/max/sum ***' @@ -2374,7 +2337,7 @@ SUBROUTINE seq_diag_avect_mct(infodata, id, av, dom, gsmap, comment) call shr_sys_flush(logunit) endif - deallocate(sumbuf,sumbufg) + deallocate(sumbufg) 100 format('comm_diag ',a3,1x,a4,1x,i3,es26.19,1x,a,1x,a) 101 format('comm_diag ',a3,1x,a4,1x,i3,es26.19,1x,a) diff --git a/src/drivers/mct/shr/seq_infodata_mod.F90 b/src/drivers/mct/shr/seq_infodata_mod.F90 index 82a984c77dd..47ec473c25b 100644 --- a/src/drivers/mct/shr/seq_infodata_mod.F90 +++ b/src/drivers/mct/shr/seq_infodata_mod.F90 @@ -176,6 +176,7 @@ MODULE seq_infodata_mod logical :: mct_usevector ! flag for mct vector logical :: reprosum_use_ddpdd ! use ddpdd algorithm + logical :: reprosum_allow_infnan ! allow INF and NaN summands real(SHR_KIND_R8) :: reprosum_diffmax ! maximum difference tolerance logical :: reprosum_recompute ! recompute reprosum with nonscalable algorithm ! if reprosum_diffmax is exceeded @@ -412,6 +413,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag) real(SHR_KIND_R8) :: eps_ogrid ! ocn grid error tolerance real(SHR_KIND_R8) :: eps_oarea ! ocn area error tolerance logical :: reprosum_use_ddpdd ! use ddpdd algorithm + logical :: reprosum_allow_infnan ! allow INF and NaN summands real(SHR_KIND_R8) :: reprosum_diffmax ! maximum difference tolerance logical :: reprosum_recompute ! recompute reprosum with nonscalable algorithm ! if reprosum_diffmax is exceeded @@ -452,7 +454,8 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag) eps_frac, eps_amask, & eps_agrid, eps_aarea, eps_omask, eps_ogrid, & eps_oarea, esmf_map_flag, & - reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, & + reprosum_use_ddpdd, reprosum_allow_infnan, & + reprosum_diffmax, reprosum_recompute, & mct_usealltoall, mct_usevector, max_cplstep_time, model_doi_url !------------------------------------------------------------------------------- @@ -560,6 +563,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag) eps_ogrid = 1.0e-02_SHR_KIND_R8 eps_oarea = 1.0e-01_SHR_KIND_R8 reprosum_use_ddpdd = .false. + reprosum_allow_infnan = .false. reprosum_diffmax = -1.0e-8 reprosum_recompute = .false. mct_usealltoall = .false. @@ -685,6 +689,7 @@ SUBROUTINE seq_infodata_Init( infodata, nmlfile, ID, pioid, cpl_tag) infodata%eps_ogrid = eps_ogrid infodata%eps_oarea = eps_oarea infodata%reprosum_use_ddpdd = reprosum_use_ddpdd + infodata%reprosum_allow_infnan = reprosum_allow_infnan infodata%reprosum_diffmax = reprosum_diffmax infodata%reprosum_recompute = reprosum_recompute infodata%mct_usealltoall = mct_usealltoall @@ -977,7 +982,8 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_ lnd_nx, lnd_ny, rof_nx, rof_ny, ice_nx, ice_ny, ocn_nx, ocn_ny, & glc_nx, glc_ny, eps_frac, eps_amask, & eps_agrid, eps_aarea, eps_omask, eps_ogrid, eps_oarea, & - reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, & + reprosum_use_ddpdd, reprosum_allow_infnan, & + reprosum_diffmax, reprosum_recompute, & atm_resume, lnd_resume, ocn_resume, ice_resume, & glc_resume, rof_resume, wav_resume, cpl_resume, & mct_usealltoall, mct_usevector, max_cplstep_time, model_doi_url, & @@ -1085,6 +1091,7 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_ real(SHR_KIND_R8), optional, intent(OUT) :: eps_ogrid ! ocn grid error tolerance real(SHR_KIND_R8), optional, intent(OUT) :: eps_oarea ! ocn area error tolerance logical, optional, intent(OUT) :: reprosum_use_ddpdd ! use ddpdd algorithm + logical, optional, intent(OUT) :: reprosum_allow_infnan ! allow INF and NaN summands real(SHR_KIND_R8), optional, intent(OUT) :: reprosum_diffmax ! maximum difference tolerance logical, optional, intent(OUT) :: reprosum_recompute ! recompute if tolerance exceeded logical, optional, intent(OUT) :: mct_usealltoall ! flag for mct alltoall @@ -1261,6 +1268,7 @@ SUBROUTINE seq_infodata_GetData_explicit( infodata, cime_model, case_name, case_ if ( present(eps_ogrid) ) eps_ogrid = infodata%eps_ogrid if ( present(eps_oarea) ) eps_oarea = infodata%eps_oarea if ( present(reprosum_use_ddpdd)) reprosum_use_ddpdd = infodata%reprosum_use_ddpdd + if ( present(reprosum_allow_infnan)) reprosum_allow_infnan = infodata%reprosum_allow_infnan if ( present(reprosum_diffmax) ) reprosum_diffmax = infodata%reprosum_diffmax if ( present(reprosum_recompute)) reprosum_recompute = infodata%reprosum_recompute if ( present(mct_usealltoall)) mct_usealltoall = infodata%mct_usealltoall @@ -1555,7 +1563,8 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_ lnd_nx, lnd_ny, rof_nx, rof_ny, ice_nx, ice_ny, ocn_nx, ocn_ny, & glc_nx, glc_ny, eps_frac, eps_amask, & eps_agrid, eps_aarea, eps_omask, eps_ogrid, eps_oarea, & - reprosum_use_ddpdd, reprosum_diffmax, reprosum_recompute, & + reprosum_use_ddpdd, reprosum_allow_infnan, & + reprosum_diffmax, reprosum_recompute, & atm_resume, lnd_resume, ocn_resume, ice_resume, & glc_resume, rof_resume, wav_resume, cpl_resume, & mct_usealltoall, mct_usevector, glc_valid_input) @@ -1661,6 +1670,7 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_ real(SHR_KIND_R8), optional, intent(IN) :: eps_ogrid ! ocn grid error tolerance real(SHR_KIND_R8), optional, intent(IN) :: eps_oarea ! ocn area error tolerance logical, optional, intent(IN) :: reprosum_use_ddpdd ! use ddpdd algorithm + logical, optional, intent(IN) :: reprosum_allow_infnan ! allow INF and NaN summands real(SHR_KIND_R8), optional, intent(IN) :: reprosum_diffmax ! maximum difference tolerance logical, optional, intent(IN) :: reprosum_recompute ! recompute if tolerance exceeded logical, optional, intent(IN) :: mct_usealltoall ! flag for mct alltoall @@ -1835,6 +1845,7 @@ SUBROUTINE seq_infodata_PutData_explicit( infodata, cime_model, case_name, case_ if ( present(eps_ogrid) ) infodata%eps_ogrid = eps_ogrid if ( present(eps_oarea) ) infodata%eps_oarea = eps_oarea if ( present(reprosum_use_ddpdd)) infodata%reprosum_use_ddpdd = reprosum_use_ddpdd + if ( present(reprosum_allow_infnan)) infodata%reprosum_allow_infnan = reprosum_allow_infnan if ( present(reprosum_diffmax) ) infodata%reprosum_diffmax = reprosum_diffmax if ( present(reprosum_recompute)) infodata%reprosum_recompute = reprosum_recompute if ( present(mct_usealltoall)) infodata%mct_usealltoall = mct_usealltoall @@ -2257,6 +2268,7 @@ subroutine seq_infodata_bcast(infodata,mpicom) call shr_mpi_bcast(infodata%eps_ogrid, mpicom) call shr_mpi_bcast(infodata%eps_oarea, mpicom) call shr_mpi_bcast(infodata%reprosum_use_ddpdd, mpicom) + call shr_mpi_bcast(infodata%reprosum_allow_infnan, mpicom) call shr_mpi_bcast(infodata%reprosum_diffmax, mpicom) call shr_mpi_bcast(infodata%reprosum_recompute, mpicom) call shr_mpi_bcast(infodata%mct_usealltoall, mpicom) @@ -2931,6 +2943,7 @@ SUBROUTINE seq_infodata_print( infodata ) write(logunit,F0R) subname,'eps_oarea = ', infodata%eps_oarea write(logunit,F0L) subname,'reprosum_use_ddpdd = ', infodata%reprosum_use_ddpdd + write(logunit,F0L) subname,'reprosum_allow_infnan = ', infodata%reprosum_allow_infnan write(logunit,F0R) subname,'reprosum_diffmax = ', infodata%reprosum_diffmax write(logunit,F0L) subname,'reprosum_recompute = ', infodata%reprosum_recompute diff --git a/src/share/util/shr_reprosum_mod.F90 b/src/share/util/shr_reprosum_mod.F90 index 9acfa54813c..a8ef29c1b15 100644 --- a/src/share/util/shr_reprosum_mod.F90 +++ b/src/share/util/shr_reprosum_mod.F90 @@ -38,7 +38,11 @@ module shr_reprosum_mod use shr_log_mod, only: s_loglev => shr_log_Level use shr_log_mod, only: s_logunit => shr_log_Unit use shr_sys_mod, only: shr_sys_abort - use shr_infnan_mod,only: shr_infnan_isnan, shr_infnan_isinf + use shr_infnan_mod,only: shr_infnan_inf_type, assignment(=), & + shr_infnan_posinf, shr_infnan_neginf, & + shr_infnan_nan, & + shr_infnan_isnan, shr_infnan_isinf, & + shr_infnan_isposinf, shr_infnan_isneginf use perf_mod !----------------------------------------------------------------------- @@ -86,12 +90,15 @@ module shr_reprosum_mod !---------------------------------------------------------------------------- logical :: repro_sum_use_ddpdd = .false. + logical :: repro_sum_allow_infnan = .false. + CONTAINS ! !======================================================================== ! subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, & + repro_sum_allow_infnan_in, & repro_sum_rel_diff_max_in, & repro_sum_recompute_in, & repro_sum_master, & @@ -104,6 +111,8 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, & !------------------------------Arguments-------------------------------- ! Use DDPDD algorithm instead of fixed precision algorithm logical, intent(in), optional :: repro_sum_use_ddpdd_in + ! Allow INF or NaN in summands + logical, intent(in), optional :: repro_sum_allow_infnan_in ! maximum permissible difference between reproducible and ! nonreproducible sums real(r8), intent(in), optional :: repro_sum_rel_diff_max_in @@ -142,6 +151,9 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, & if ( present(repro_sum_use_ddpdd_in) ) then repro_sum_use_ddpdd = repro_sum_use_ddpdd_in endif + if ( present(repro_sum_allow_infnan_in) ) then + repro_sum_allow_infnan = repro_sum_allow_infnan_in + endif if ( present(repro_sum_rel_diff_max_in) ) then shr_reprosum_reldiffmax = repro_sum_rel_diff_max_in endif @@ -159,6 +171,14 @@ subroutine shr_reprosum_setopts(repro_sum_use_ddpdd_in, & 'distributed sum algorithm' endif + if ( repro_sum_allow_infnan ) then + write(logunit,*) 'SHR_REPROSUM_SETOPTS: ',& + 'Will calculate sum when INF or NaN are included in summands' + else + write(logunit,*) 'SHR_REPROSUM_SETOPTS: ',& + 'Will abort if INF or NaN are included in summands' + endif + if (shr_reprosum_reldiffmax >= 0._r8) then write(logunit,*) ' ',& 'with a maximum relative error tolerance of ', & @@ -185,7 +205,7 @@ end subroutine shr_reprosum_setopts ! subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & - nflds, ddpdd_sum, & + nflds, allow_infnan, ddpdd_sum, & arr_gbl_max, arr_gbl_max_out, & arr_max_levels, arr_max_levels_out, & gbl_max_nsummands, gbl_max_nsummands_out,& @@ -280,6 +300,10 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! use ddpdd algorithm instead ! of fixed precision algorithm + logical, intent(in), optional :: allow_infnan + ! if .true., allow INF or NaN input values. + ! if .false. (the default), then abort. + real(r8), intent(in), optional :: arr_gbl_max(nflds) ! upper bound on max(abs(arr)) @@ -312,13 +336,14 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! flag enabling/disabling testing that gmax and max_levels are ! accurate/sufficient. Default is enabled. - integer, intent(inout), optional :: repro_sum_stats(5) + integer, intent(inout), optional :: repro_sum_stats(6) ! increment running totals for ! (1) one-reduction repro_sum ! (2) two-reduction repro_sum ! (3) both types in one call ! (4) nonrepro_sum ! (5) global max nsummands reduction + ! (6) global lor 3*nflds reduction real(r8), intent(out), optional :: rel_diff(2,nflds) ! relative and absolute @@ -331,6 +356,8 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! ! Local workspace ! + logical :: abort_inf_nan ! flag indicating whether to + ! abort if INF or NaN found in input logical :: use_ddpdd_sum ! flag indicating whether to ! use shr_reprosum_ddpdd or not logical :: recompute ! flag indicating need to @@ -341,8 +368,23 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! are accurate/sufficient logical :: nan_check, inf_check ! flag on whether there are ! NaNs and INFs in input array + logical :: inf_nan_lchecks(3,nflds)! flags on whether there are + ! NaNs, positive INFs, or negative INFs + ! for each input field locally + logical :: inf_nan_gchecks(3,nflds)! flags on whether there are + ! NaNs, positive INFs, or negative INFs + ! for each input field + logical :: arr_gsum_infnan(nflds) ! flag on whether field sum is a + ! NaN or INF + + integer :: gbl_lor_red ! global lor reduction? (0/1) + integer :: gbl_max_red ! global max reduction? (0/1) + integer :: repro_sum_fast ! 1 reduction repro_sum? (0/1) + integer :: repro_sum_slow ! 2 reduction repro_sum? (0/1) + integer :: repro_sum_both ! both fast and slow? (0/1) + integer :: nonrepro_sum ! nonrepro_sum? (0/1) - integer :: num_nans, num_infs ! count of NaNs and INFs in + integer :: nan_count, inf_count ! local count of NaNs and INFs in ! input array integer :: omp_nthreads ! number of OpenMP threads integer :: mpi_comm ! MPI subcommunicator @@ -375,11 +417,6 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & integer :: max_levels(nflds) ! maximum number of levels of ! integer expansion to use integer :: max_level ! maximum value in max_levels - integer :: gbl_max_red ! global max local sum reduction? (0/1) - integer :: repro_sum_fast ! 1 reduction repro_sum? (0/1) - integer :: repro_sum_slow ! 2 reduction repro_sum? (0/1) - integer :: repro_sum_both ! both fast and slow? (0/1) - integer :: nonrepro_sum ! nonrepro_sum? (0/1) real(r8) :: xmax_nsummands ! dble of max_nsummands real(r8) :: arr_lsum(nflds) ! local sums @@ -396,38 +433,81 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! !----------------------------------------------------------------------- ! -! check whether input contains NaNs or INFs, and abort if so +! initialize local statistics variables + gbl_lor_red = 0 + gbl_max_red = 0 + repro_sum_fast = 0 + repro_sum_slow = 0 + repro_sum_both = 0 + nonrepro_sum = 0 - call t_startf('shr_reprosum_NaN_INF_Chk') - nan_check = .false. - inf_check = .false. - num_nans = 0 - num_infs = 0 +! set MPI communicator + if ( present(commid) ) then + mpi_comm = commid + else + mpi_comm = MPI_COMM_WORLD + endif + call t_barrierf('sync_repro_sum',mpi_comm) - nan_check = any(shr_infnan_isnan(arr)) - inf_check = any(shr_infnan_isinf(arr)) - if (nan_check .or. inf_check) then - do ifld=1,nflds - do isum=1,nsummands - if (shr_infnan_isnan(arr(isum,ifld))) then - num_nans = num_nans + 1 - endif - if (shr_infnan_isinf(arr(isum,ifld))) then - num_infs = num_infs + 1 - endif - end do - end do +! check whether should abort if input contains NaNs or INFs + abort_inf_nan = .not. repro_sum_allow_infnan + if ( present(allow_infnan) ) then + abort_inf_nan = .not. allow_infnan endif - call t_stopf('shr_reprosum_NaN_INF_Chk') - if ((num_nans > 0) .or. (num_infs > 0)) then - call mpi_comm_rank(MPI_COMM_WORLD, mypid, ierr) - write(s_logunit,37) real(num_nans,r8), real(num_infs,r8), mypid + call t_startf('shr_reprosum_INF_NaN_Chk') + +! initialize flags to indicate that no NaNs or INFs are present in the input data + inf_nan_gchecks = .false. + arr_gsum_infnan = .false. + + if (abort_inf_nan) then + +! check whether input contains NaNs or INFs, and abort if so + nan_check = any(shr_infnan_isnan(arr)) + inf_check = any(shr_infnan_isinf(arr)) + + if (nan_check .or. inf_check) then + + nan_count = count(shr_infnan_isnan(arr)) + inf_count = count(shr_infnan_isinf(arr)) + + if ((nan_count > 0) .or. (inf_count > 0)) then + call mpi_comm_rank(MPI_COMM_WORLD, mypid, ierr) + write(s_logunit,37) real(nan_count,r8), real(inf_count,r8), mypid 37 format("SHR_REPROSUM_CALC: Input contains ",e12.5, & " NaNs and ", e12.5, " INFs on process ", i7) - call shr_sys_abort("shr_reprosum_calc ERROR: NaNs or INFs in input") + call shr_sys_abort("shr_reprosum_calc ERROR: NaNs or INFs in input") + endif + + endif + + else + +! determine whether any fields contain NaNs or INFs, and avoid processing them +! via integer expansions + inf_nan_lchecks = .false. + + do ifld=1,nflds + inf_nan_lchecks(1,ifld) = any(shr_infnan_isnan(arr(:,ifld))) + inf_nan_lchecks(2,ifld) = any(shr_infnan_isposinf(arr(:,ifld))) + inf_nan_lchecks(3,ifld) = any(shr_infnan_isneginf(arr(:,ifld))) + end do + + call t_startf("repro_sum_allr_lor") + call mpi_allreduce (inf_nan_lchecks, inf_nan_gchecks, 3*nflds, & + MPI_LOGICAL, MPI_LOR, mpi_comm, ierr) + gbl_lor_red = 1 + call t_stopf("repro_sum_allr_lor") + + do ifld=1,nflds + arr_gsum_infnan(ifld) = any(inf_nan_gchecks(:,ifld)) + enddo + endif + call t_stopf('shr_reprosum_INF_NaN_Chk') + ! check whether should use shr_reprosum_ddpdd algorithm use_ddpdd_sum = repro_sum_use_ddpdd if ( present(ddpdd_sum) ) then @@ -439,21 +519,6 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! If not, always use ddpdd. use_ddpdd_sum = use_ddpdd_sum .or. (radix(0._r8) /= radix(0_i8)) -! initialize local statistics variables - gbl_max_red = 0 - repro_sum_fast = 0 - repro_sum_slow = 0 - repro_sum_both = 0 - nonrepro_sum = 0 - -! set MPI communicator - if ( present(commid) ) then - mpi_comm = commid - else - mpi_comm = MPI_COMM_WORLD - endif - call t_barrierf('sync_repro_sum',mpi_comm) - if ( use_ddpdd_sum ) then call t_startf('shr_reprosum_ddpdd') @@ -548,8 +613,8 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & endif call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, & nflds, arr_max_shift, arr_gmax_exp, & - arr_max_levels, max_level, validate, & - recompute, omp_nthreads, mpi_comm) + arr_max_levels, max_level, arr_gsum_infnan, & + validate, recompute, omp_nthreads, mpi_comm) ! record statistics, etc. repro_sum_fast = 1 @@ -598,13 +663,15 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & do ifld=1,nflds arr_exp_tlmin = MAXEXPONENT(1._r8) arr_exp_tlmax = MINEXPONENT(1._r8) - do isum=isum_beg(ithread),isum_end(ithread) - if (arr(isum,ifld) .ne. 0.0_r8) then - arr_exp = exponent(arr(isum,ifld)) - arr_exp_tlmin = min(arr_exp,arr_exp_tlmin) - arr_exp_tlmax = max(arr_exp,arr_exp_tlmax) - endif - end do + if (.not. arr_gsum_infnan(ifld)) then + do isum=isum_beg(ithread),isum_end(ithread) + if (arr(isum,ifld) .ne. 0.0_r8) then + arr_exp = exponent(arr(isum,ifld)) + arr_exp_tlmin = min(arr_exp,arr_exp_tlmin) + arr_exp_tlmax = max(arr_exp,arr_exp_tlmax) + endif + end do + endif arr_tlmin_exp(ifld,ithread) = arr_exp_tlmin arr_tlmax_exp(ifld,ithread) = arr_exp_tlmax end do @@ -628,9 +695,9 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & arr_gmax_exp(:) = -arr_gextremes(1:nflds,1) arr_gmin_exp(:) = arr_gextremes(1:nflds,2) -! if a field is identically zero, arr_gmin_exp still equals MAXEXPONENT -! and arr_gmax_exp still equals MINEXPONENT. In this case, set -! arr_gmin_exp = arr_gmax_exp = MINEXPONENT +! if a field is identically zero or contains INFs or NaNs, arr_gmin_exp +! still equals MAXEXPONENT and arr_gmax_exp still equals MINEXPONENT. +! In this case, set arr_gmin_exp = arr_gmax_exp = MINEXPONENT do ifld=1,nflds arr_gmin_exp(ifld) = min(arr_gmax_exp(ifld),arr_gmin_exp(ifld)) enddo @@ -695,10 +762,10 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & ! calculate sum validate = .false. - call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, nflds, & - arr_max_shift, arr_gmax_exp, max_levels, & - max_level, validate, recompute, & - omp_nthreads, mpi_comm) + call shr_reprosum_int(arr, arr_gsum, nsummands, dsummands, & + nflds, arr_max_shift, arr_gmax_exp, & + max_levels, max_level, arr_gsum_infnan, & + validate, recompute, omp_nthreads, mpi_comm) endif @@ -720,13 +787,17 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & !$omp default(shared) & !$omp private(ifld, isum) do ifld=1,nflds - do isum=1,nsummands - arr_lsum(ifld) = arr(isum,ifld) + arr_lsum(ifld) - end do + if (.not. arr_gsum_infnan(ifld)) then + do isum=1,nsummands + arr_lsum(ifld) = arr(isum,ifld) + arr_lsum(ifld) + end do + endif end do + call t_startf("nonrepro_sum_allr_r8") call mpi_allreduce (arr_lsum, arr_gsum_fast, nflds, & MPI_REAL8, MPI_SUM, mpi_comm, ierr) + call t_stopf("nonrepro_sum_allr_r8") call t_stopf('nonrepro_sum') @@ -748,6 +819,25 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & endif endif +! Set field sums to NaN and INF, as needed + do ifld=1,nflds + if (arr_gsum_infnan(ifld)) then + if (inf_nan_gchecks(1,ifld)) then + ! NaN => NaN + arr_gsum(ifld) = shr_infnan_nan + else if (inf_nan_gchecks(2,ifld) .and. inf_nan_gchecks(3,ifld)) then + ! posINF and negINF => NaN + arr_gsum(ifld) = shr_infnan_nan + else if (inf_nan_gchecks(2,ifld)) then + ! posINF only => posINF + arr_gsum(ifld) = shr_infnan_posinf + else if (inf_nan_gchecks(3,ifld)) then + ! negINF only => negINF + arr_gsum(ifld) = shr_infnan_neginf + endif + endif + end do + ! return statistics if ( present(repro_sum_stats) ) then repro_sum_stats(1) = repro_sum_stats(1) + repro_sum_fast @@ -755,6 +845,7 @@ subroutine shr_reprosum_calc (arr, arr_gsum, nsummands, dsummands, & repro_sum_stats(3) = repro_sum_stats(3) + repro_sum_both repro_sum_stats(4) = repro_sum_stats(4) + nonrepro_sum repro_sum_stats(5) = repro_sum_stats(5) + gbl_max_red + repro_sum_stats(6) = repro_sum_stats(6) + gbl_lor_red endif @@ -766,7 +857,7 @@ end subroutine shr_reprosum_calc subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, & arr_max_shift, arr_gmax_exp, max_levels, & - max_level, validate, recompute, & + max_level, skip_field, validate, recompute, & omp_nthreads, mpi_comm ) !---------------------------------------------------------------------- ! @@ -798,9 +889,14 @@ subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, & integer, intent(in) :: mpi_comm ! MPI subcommunicator real(r8), intent(in) :: arr(dsummands,nflds) - ! input array + ! input array - logical, intent(in):: validate + logical, intent(in) :: skip_field(nflds) + ! flag indicating whether the sum for this field should be + ! computed or not (used to skip over fields containing + ! NaN or INF summands) + + logical, intent(in) :: validate ! flag indicating that accuracy of solution generated from ! arr_gmax_exp and max_levels should be tested @@ -920,8 +1016,10 @@ subroutine shr_reprosum_int (arr, arr_gsum, nsummands, dsummands, nflds, & max_error(ifld,ithread) = 0 not_exact(ifld,ithread) = 0 - i8_arr_tlsum_level(:,ifld,ithread) = 0_i8 + + if (skip_field(ifld)) cycle + do isum=isum_beg(ithread),isum_end(ithread) arr_remainder = 0.0_r8 @@ -1370,8 +1468,11 @@ subroutine shr_reprosum_ddpdd (arr, arr_gsum, nsummands, dsummands, & enddo + call t_startf("repro_sum_allr_c16") call mpi_allreduce (arr_lsum_dd, arr_gsum_dd, nflds, & MPI_COMPLEX16, mpi_sumdd, mpi_comm, ierr) + call t_stopf("repro_sum_allr_c16") + do ifld=1,nflds arr_gsum(ifld) = real(arr_gsum_dd(ifld)) enddo