Skip to content

Commit

Permalink
stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Nov 10, 2023
1 parent af28546 commit 05c0b1d
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ DynamicGridsMakieExt = "Makie"

[compat]
Adapt = "2, 3"
CUDA = "5"
CUDA = "4, 5"
Colors = "0.9, 0.10, 0.11, 0.12"
ConstructionBase = "1"
Crayons = "4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Ruleset
BoundaryCondition
Wrap
Remove
Ignore
```

### Hardware selection
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicGrids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ export Processor, SingleCPU, ThreadedCPU, CPUGPU, CuGPU

export PerformanceOpt, NoOpt, SparseOpt

export BoundaryCondition, Remove, Wrap
export BoundaryCondition, Remove, Wrap, Ignore

export ParameterSource, Aux, Grid, Delay, Lag, Frame

Expand Down
18 changes: 18 additions & 0 deletions src/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ end
@propagate_inbounds function _setindex!(d::GridData{<:SwitchMode}, opt::GPU, x, I...)
dest(d)[I...] = x
end

function _maybemask!(
wgrid::GridData{<:GridMode,<:Tuple{Y,X}}, proc::GPU, mask::AbstractArray
) where {Y,X}
pv = padval(wgrid)
kernel! = ka_mask_kernel!(kernel_setup(proc)...)
kernel!(source(wgrid), mask, pv; ndrange=size(wgrid))
return nothing
end

@kernel function ka_mask_kernel!(grid, mask, padval)
I = @index(Global, NTuple)
mask[I...] ? grid[I...] : padval
end
@kernel function ka_mask_kernel!(grid, mask, padval::Nothing)
I = @index(Global, NTuple)
mask[I...] * grid[I...]
end
49 changes: 34 additions & 15 deletions src/maprules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ function maprule!(
# UNSAFE: we must avoid sharing status blocks, it could cause race conditions
# when setting status from different threads. So we split the grid in 2 interleaved
# sets of rows, so that we never run adjacent rows simultaneously
map_on_processor(proc, 1:2:_indtoblock(Y, B)) do bi
map_on_processor(proc, data, 1:2:_indtoblock(Y, B)) do bi
row_kernel!(data, hoodgrid, proc, opt, ruletype, rule, rkeys, wkeys, bi)
end
map_on_processor(proc, 2:2:_indtoblock(Y, B)) do bi
map_on_processor(proc, data, 2:2:_indtoblock(Y, B)) do bi
row_kernel!(data, hoodgrid, proc, opt, ruletype, rule, rkeys, wkeys, bi)
end
end
Expand All @@ -122,7 +122,7 @@ end
function map_with_optimisation(
f, simdata::AbstractSimData{S}, proc, ::NoOpt, ::Val{<:Rule}, rkeys
) where S<:Tuple{Y,X} where {Y,X}
map_on_processor(proc, 1:X) do j
map_on_processor(proc, simdata, 1:X) do j
for i in 1:Y
f((i, j)) # Run rule for each row in column j
end
Expand All @@ -132,7 +132,7 @@ end
function map_with_optimisation(
f, simdata::AbstractSimData{S}, proc, ::NoOpt, ::Val{<:CellRule}, rkeys
) where S<:Tuple{Y,X} where {Y,X}
map_on_processor(proc, 1:X) do j
map_on_processor(proc, simdata, 1:X) do j
@simd for i in 1:Y
f((i, j)) # Run rule for each row in column j
end
Expand All @@ -143,30 +143,51 @@ end
# Map kernel over the grid, specialising on the processor
#
# Looping over cells or blocks on a single CPU
@inline function map_on_processor(f, proc::SingleCPU, range)
@inline function map_on_processor(f, proc::SingleCPU, data, range)
for n in range
f(n) # Run rule over each column
end
end
# Or threaded on multiple CPUs
@inline function map_on_processor(f, proc::ThreadedCPU, range)
Threads.@threads for n in range
f(n) # Run rule over each column, threaded
@inline function map_on_processor(f, proc::ThreadedCPU, data, rnge)
# We don't want to share memory between neighborhoods
min_cols = max(3, 2radius(data) + 1)
N = Threads.nthreads()
allchunks = collect(Iterators.partition(rnge, min_cols))
chunks = map(1:N) do i
allchunks[i:N:end]
end
tasks = map(chunks) do chunk
Threads.@spawn begin
for subchunk in chunk
for n in subchunk
f(n)
end
end
end
end
states = fetch.(tasks)
return nothing
end

# cell_kernel!
# runs a rule for the current cell
@inline function cell_kernel!(simdata, ruletype::Val{<:Rule}, rule, rkeys, wkeys, I...)
if !isnothing(mask(simdata))
mask(simdata)[I...] || return nothing
end
readval = _readcell(simdata, rkeys, I...)
writeval = applyrule(simdata, rule, readval, I)
_writecell!(simdata, ruletype, wkeys, writeval, I...)
writeval
return writeval
end
@inline function cell_kernel!(simdata, ::Val{<:SetRule}, rule, rkeys, wkeys, I...)
if !isnothing(mask(simdata))
mask(simdata)[I...] || return nothing
end
readval = _readcell(simdata, rkeys, I...)
applyrule!(simdata, rule, readval, I)
nothing
return nothing
end

# stencil_kernel!
Expand All @@ -185,7 +206,7 @@ end
# data in the stencil windows array across by one column. This saves on reads
# from the main array.
function row_kernel!(
simdata::AbstractSimData, grid::GridData{<:Any,<:Tuple{Y,X},R}, proc, opt::NoOpt,
simdata::AbstractSimData, grid::GridData{<:GridMode,<:Tuple{Y,X},R}, proc, opt::NoOpt,
ruletype::Val, rule::Rule, rkeys, wkeys, bi
) where {Y,X,R}
B = 2R
Expand All @@ -194,11 +215,9 @@ function row_kernel!(
# Loop along the block ROW.
blocklen = min(Y, i + B - 1) - i + 1
for j = 1:X
# windows = _slide_windows(windows, src, Val{R}(), i, j)
# Loop over the COLUMN of windows covering the block
for b in 1:blocklen
rule1 = Stencils.rebuild(rule, unsafe_neighbors(stencil(rule), grid, CartesianIndex(i, j)))
cell_kernel!(simdata, ruletype, rule1, rkeys, wkeys, i + b - 1, j)
stencil_kernel!(simdata, grid, ruletype, rule, rkeys, wkeys, i + b - 1, j)
end
end
return nothing
Expand All @@ -221,7 +240,7 @@ function _maybemask!(
) where {Y,X}
A = source(wgrid)
mv = maskval(wgrid)
map_on_processor(proc, 1:X) do j
map_on_processor(proc, wgrid, 1:X) do j
if isnothing(mv) || mv == zero(eltype(wgrid))
@simd for i in 1:Y
A[i, j] *= mask[i, j]
Expand Down
50 changes: 17 additions & 33 deletions src/parametersources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,45 +133,29 @@ end
# matching type to the simulation tspan.
# This is called from _updatetime in simulationdata.jl
_calc_auxframe(data::AbstractSimData) = _calc_auxframe(aux(data), data)
function _calc_auxframe(aux::NamedTuple, data::AbstractSimData)
map(A -> _calc_auxframe(A, data), aux)
function _calc_auxframe(aux::NamedTuple{K}, data::AbstractSimData) where K
map((A, k) -> _calc_auxframe(A, data, k), aux, NamedTuple{K}(K))
end
function _calc_auxframe(A::AbstractDimArray, data)
function _calc_auxframe(A::AbstractDimArray, data, key)
hasdim(A, TimeDim) || return nothing
timedim = dims(A, TimeDim)
curtime = currenttime(data)
firstauxtime = first(timedim)
# For Irregular we use `Contains` to get the nearest matching timestep
if span(timedim) isa Irregular
return DimensionalData.selectindices(timedim, Contains(curtime))
if !hasselection(timedim, Near(curtime))
if lookup(timedim) isa Cyclic
if sampling(timedim) isa Points
throw(ArgumentError("Time dimension of aux `$key` has no valid selection for `Contains($curtime)`. Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact."))
else
throw(ArgumentError("Time dimension of aux `$key` has no valid selection for `Contains($curtime)`."))
end
elseif sampling(timedim) isa Points
throw(ArgumentError("Time dimension of aux `$key` has no valid selection for `Contains($curtime)`. Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact."))
else
throw(ArgumentError("aux `$key` has no valid selection for `Contains($curtime)`. Did you mean to use a `Cyclic` lookup for the time dimension of the array?"))
end
end
auxstep = step(timedim)
# Use julias range objects to calculate the distance between the
# current time and the start of the aux
i = if curtime >= firstauxtime
length(firstauxtime:auxstep:curtime)
else
1 - length(firstauxtime-timestep(data):-auxstep:curtime)
end
# TODO use a cyclic mode DimensionalArray
# and handle the mismatch of e.g. weeks and years
return _cyclic_index(i, size(A, Ti))
return DimensionalData.selectindices(timedim, Near(curtime))
end
_calc_auxframe(aux, data) = nothing

# _cyclic_index
# Cycle an index over the length of the aux data.
function _cyclic_index(i::Integer, len::Integer)
return if i > len
rem(i + len - 1, len) + 1
elseif i <= 0
i + (i ÷ len -1) * -len
else
i
end
end


_calc_auxframe(aux, data, key) = nothing

"""
Grid <: ParameterSource
Expand Down
8 changes: 6 additions & 2 deletions src/simulationdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ aux(d::AbstractSimData, args...) = aux(extent(d), args...)
auxframe(d::AbstractSimData, key) = auxframe(d)[_unwrap(key)]
tspan(d::AbstractSimData) = tspan(extent(d))
timestep(d::AbstractSimData) = step(tspan(d))
radius(d::AbstractSimData) = max(map(radius, grids(d))...)

# Calculated:
# Get the current time for this frame
Expand Down Expand Up @@ -179,6 +180,7 @@ end

_boundary(::Wrap, padval) = Wrap()
_boundary(::Remove, padval) = Remove(padval)
_boundary(::Ignore, padval) = Ignore()

ConstructionBase.constructorof(::Type{<:SimData{S,N}}) where {S,N} = SimData{S,N}

Expand Down Expand Up @@ -234,8 +236,10 @@ function RuleData{S,N}(
) where {S,N,G,E,Se,F,CF,AF}
RuleData{S,N,G,E,Se,F,CF,AF}(grids, extent, settings, frames, currentframe, auxframe)
end
function RuleData(d::AbstractSimData{S,N}) where {S,N}
RuleData{S,N}(grids(d), extent(d), settings(d), frames(d), currentframe(d), auxframe(d))
function RuleData(d::AbstractSimData{S,N};
grids=grids(d), extent=extent(d), settings=settings(d), frames=frames(d), currentframe=currentframe(d), auxframe=auxframe(d)
) where {S,N}
return RuleData{S,N}(grids, extent, settings, frames, currentframe, auxframe)
end

ConstructionBase.constructorof(::Type{<:RuleData{S,N}}) where {S,N} = RuleData{S,N}
Expand Down
39 changes: 19 additions & 20 deletions test/parametersources.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,39 @@
using DynamicGrids, Dates, DimensionalData, Setfield, Unitful, Test
using DynamicGrids, Dates, DimensionalData, Setfield, Unitful, Test, DimensionalData
using Unitful: d
using DynamicGrids: SimData, Extent, _calc_auxframe, _cyclic_index
using DynamicGrids: SimData, Extent, _calc_auxframe
const DG = DynamicGrids

@testset "Aux" begin
@testset "sequence cycling" begin
@test _cyclic_index(-4, 2) == 2
@test _cyclic_index(-3, 2) == 1
@test _cyclic_index(-2, 2) == 2
@test _cyclic_index(-1, 2) == 1
@test _cyclic_index(0, 2) == 2
@test _cyclic_index(1, 2) == 1
@test _cyclic_index(2, 2) == 2
@test _cyclic_index(3, 2) == 1
@test _cyclic_index(4, 2) == 2
@test _cyclic_index(20, 10) == 10
@test _cyclic_index(21, 10) == 1
@test _cyclic_index(27, 10) == 7
end

@testset "aux sequence" begin
a = cat([0.1 0.2; 0.3 0.4], [1.1 1.2; 1.3 1.4], [2.1 2.2; 2.3 2.4]; dims=3)

@testset "the correct frame is calculated for aux data" begin
dimz = X(1:2), Y(1:2), Ti(15d:5d:25d)
dimz = X(1:2), Y(1:2), Ti(DimensionalData.Cyclic(1d:5d:14d; order=ForwardOrdered(), cycle=15d, sampling=Intervals(Start())))
seq = DimArray(a, dimz)
init = zero(seq[Ti(1)])
sd = SimData(Extent(init=init, aux=(seq=seq,), tspan=1d:1d:100d), Ruleset())
@test DynamicGrids.boundscheck_aux(sd, Aux{:seq}()) == true
tests = (1, 1), (4, 1), (5, 2), (6, 2), (9, 2), (10, 3), (11, 3), (14, 3), (15, 1),
(19, 1), (20, 2), (25, 3), (29, 3), (30, 1), (34, 1), (35, 2)
tests = (1, 1), (4, 1), (5, 1), (6, 2), (9, 2), (10, 2), (11, 3), (14, 3), (15, 3),
(20, 1), (21, 2), (26, 3), (30, 3), (31, 1), (35, 1), (36, 2)
for (f, ref_af) in tests
@set! sd.currentframe = f
af = _calc_auxframe(sd).seq
@test af == ref_af
end
# Not Cycled
dimz = X(1:2), Y(1:2), Ti(1d:5d:14d; order=ForwardOrdered(), cycle=15d, sampling=Intervals(Start()))
seq = DimArray(a, dimz)
init = zero(seq[Ti(1)])
sd = SimData(Extent(init=init, aux=(seq=seq,), tspan=1d:1d:100d), Ruleset())
@set! sd.currentframe = 20
@test_throws ArgumentError _calc_auxframe(sd).seq
# Not Intervals
dimz = X(1:2), Y(1:2), Ti(DimensionalData.Cyclic(1d:5d:14d; order=ForwardOrdered(), cycle=15d, sampling=Points()))
seq = DimArray(a, dimz)
init = zero(seq[Ti(1)])
sd = SimData(Extent(init=init, aux=(seq=seq,), tspan=1d:1d:100d), Ruleset())
@set! sd.currentframe = 7
@test_throws ArgumentError _calc_auxframe(sd).seq
end

@testset "boundscheck_aux" begin
Expand Down
3 changes: 1 addition & 2 deletions test/wrappers/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ using DynamicGrids: SimData, radius, rules, _readkeys, _writekeys,

@test radius(ruleset) == (b=0, c=0, d=0, e=0, a=0)

@test applyrule(data, chain, (b=1, c=1, d=1, a=1), (1, 1)) ==
(4, 6, 10, 3, 2)
@test applyrule(data, chain, (b=1, c=1, d=1, a=1), (1, 1)) == (4, 6, 10, 3, 2)

# @inferred applyrule(data, chain, (b=1, c=1, d=1, a=1), (1, 1))

Expand Down

0 comments on commit 05c0b1d

Please sign in to comment.