Skip to content

Commit

Permalink
Improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed Nov 15, 2024
1 parent 3ac87fb commit 8ce6aa1
Showing 1 changed file with 132 additions and 29 deletions.
161 changes: 132 additions & 29 deletions src/nhs_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,79 +395,184 @@ end
return nothing
end

@inline function copy_to_localmem!(local_points, local_neighbor_coords,
neighbor_cell, neighbor_system_coords,
neighborhood_search, particleidx)
points_view = points_in_cell(neighbor_cell, neighborhood_search)
n_particles_in_neighbor_cell = length(points_view)

# First use all threads to load the neighbors into local memory in parallel
if particleidx <= n_particles_in_neighbor_cell
@inbounds p = local_points[particleidx] = points_view[particleidx]
for d in 1:ndims(neighborhood_search)
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
end
end
return n_particles_in_neighbor_cell
end

# @parallel(block) for cell in cells
# for neighbor_cell in neighboring_cells
# @parallel(thread) for neighbor in neighbor_cell
# copy_coordinates_to_localmem(neighbor)
#
# # Make sure all threads finished the copying
# @synchronize
#
# @parallel(thread) for particle in cell
# for neighbor in neighbor_cell
# # This uses the neighbor coordinates from the local memory
# compute(point, neighbor)
#
# # Make sure all threads finished computing before we continue with copying
# @synchronize
@kernel cpu=false function foreach_neighbor_localmem(f::F, system_coords, neighbor_system_coords,
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
cell_ = @index(Group)
cell = @inbounds Tuple(cells[cell_])
particleidx = @index(Local)
@assert 1 <= particleidx <= MAX

# Coordinate buffer in local memory
local_points = @localmem Int32 MAX
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)

next_local_points = @localmem Int32 MAX
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
points = points_in_cell(cell, neighborhood_search)
n_particles_in_current_cell = length(points)

pv = points_in_cell(cell, neighborhood_search)
n_particles_in_current_cell = length(pv)
# Extract point coordinates if a point lies on this thread
if particleidx <= n_particles_in_current_cell
point = @inbounds pv[particleidx]
point = @inbounds points[particleidx]
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
point)
else
point = zero(Int32)
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
end

@inline function stage!(local_points, local_neighbor_coords, neighbor_cell)
points_view = points_in_cell(neighbor_cell, neighborhood_search)
n_particles_in_neighbor_cell_ = length(points_view)
for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
neighbor_cell = Tuple(neighbor_cell_)

n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
neighbor_cell, neighbor_system_coords,
neighborhood_search, particleidx)

# Make sure all threads finished the copying
@synchronize

# First use all threads to load the neighbors into local memory in parallel
if particleidx <= n_particles_in_neighbor_cell_
@inbounds p = local_points[particleidx] = points_view[particleidx]
for d in 1:ndims(neighborhood_search)
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
# Now each thread works on one point again
if particleidx <= n_particles_in_current_cell
for local_neighbor in 1:n_particles_in_neighbor_cell
@inbounds neighbor = local_points[local_neighbor]
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
Val(ndims(neighborhood_search)),
local_neighbor)

pos_diff = point_coords - neighbor_coords
distance2 = dot(pos_diff, pos_diff)

# TODO periodic

if distance2 <= search_radius^2
distance = sqrt(distance2) # TODO: eventuell fastmath

# Inline to avoid loss of performance
# compared to not using `foreach_point_neighbor`.
@inline f(point, neighbor, pos_diff, distance)
end
end
end
return n_particles_in_neighbor_cell_

# Make sure all threads finished computing before we continue with copying
@synchronize()
end
end

# @parallel(block) for cell in cells
# @parallel(thread) for neighbor in first_neighbor_cell
# copy_coordinates_to_localmem!(local_coords, neighbor)
#
# for neighbor_cell in neighboring_cells
# @parallel(thread) for neighbor in neighbor_cell + 1
# copy_coordinates_to_localmem!(next_local_coords, neighbor)
#
# # No synchronize needed. The following loop works on `local_coords`.
#
# @parallel(thread) for particle in cell
# for neighbor in neighbor_cell
# # This uses the neighbor coordinates from the local memory
# compute(point, neighbor)
#
# # Make sure all threads finished computing before we switch variables
# @synchronize
# local_coords, next_local_coords = next_local_coords, local_coords
@kernel cpu=false function foreach_neighbor_double_buffer(f::F, system_coords, neighbor_system_coords,
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
cell_ = @index(Group)
cell = @inbounds Tuple(cells[cell_])
particleidx = @index(Local)
@assert 1 <= particleidx <= MAX

# Coordinate buffer in local memory
local_points = @localmem Int32 MAX
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)

# Next coordinate buffer in local memory
next_local_points = @localmem Int32 MAX
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)

points = points_in_cell(cell, neighborhood_search)
n_particles_in_current_cell = length(points)

# Extract point coordinates if a point lies on this thread
if particleidx <= n_particles_in_current_cell
point = @inbounds points[particleidx]
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
point)
else
point = zero(Int32)
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
end

neighborhood = neighboring_cells(cell, neighborhood_search)
# (neighbor_cell, state) = iterate(neighborhood)
neighbor_cell = first(neighborhood)
neighbor_cell = Tuple(first(neighborhood))

n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
neighbor_cell, neighbor_system_coords,
neighborhood_search, particleidx)

n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
@synchronize()

for neighbor_ in 1:length(neighborhood)
neighbor_cell = @inbounds neighborhood[neighbor_]
neighbor_cell = @inbounds Tuple(neighborhood[neighbor_])

# while true
# next = iterate(neighborhood, state)
# if next !== nothing
# n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
# @synchronize

if neighbor_ < length(neighborhood)
next_neighbor_cell = neighborhood[neighbor_ + 1]
next_neighbor_cell = @inbounds Tuple(neighborhood[neighbor_ + 1])
# (next_neighbor_cell, state) = next
next_n_particles_in_neighbor_cell = stage!(next_local_points, next_local_neighbor_coords, Tuple(next_neighbor_cell))
next_n_particles_in_neighbor_cell = copy_to_localmem!(next_local_points, next_local_neighbor_coords,
next_neighbor_cell, neighbor_system_coords,
neighborhood_search, particleidx)
end

# Now each thread works on one point again
if particleidx <= n_particles_in_current_cell
for local_neighbor in 1:n_particles_in_neighbor_cell
@inbounds neighbor = local_points[local_neighbor]
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
Val(ndims(neighborhood_search)), local_neighbor)
Val(ndims(neighborhood_search)),
local_neighbor)

pos_diff = point_coords - neighbor_coords
distance2 = dot(pos_diff, pos_diff)

# TODO periodic

if distance2 <= search_radius^2
# KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
distance = sqrt(distance2) # TODO: eventuell fastmath

# Inline to avoid loss of performance
Expand All @@ -476,17 +581,15 @@ end
end
end
end

# next === nothing && break
neighbor_ >= length(neighborhood) && break
@synchronize()

# swap variables
n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell
temp = local_points
local_points = next_local_points
next_local_points = temp
temp = local_neighbor_coords
local_neighbor_coords = next_local_neighbor_coords
next_local_neighbor_coords = temp
local_points, next_local_points = next_local_points, local_points
local_neighbor_coords, next_local_neighbor_coords = next_local_neighbor_coords, local_neighbor_coords
end
end

Expand Down

0 comments on commit 8ce6aa1

Please sign in to comment.