-
Notifications
You must be signed in to change notification settings - Fork 8
Implement neighborhood search based on CellListMap.jl #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
82350dd
f7f48e8
3002d46
357ebd2
c8f1ce2
1d1c078
d06e48f
67d5422
e3a9637
b874b5f
0ba9dad
aad3e13
e6e374f
cbf25c5
7779a7e
4e4ec1d
39ac331
fbf2e99
d7000c8
1326b52
d183c14
24f0c85
b5aca7d
58f5b57
fabafbb
2cc18a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,189 @@ | ||||||
module PointNeighborsCellListMapExt | ||||||
|
||||||
using PointNeighbors | ||||||
using CellListMap: CellListMap, CellList, CellListPair | ||||||
|
||||||
""" | ||||||
CellListMapNeighborhoodSearch(NDIMS; search_radius = 1.0, points_equal_neighbors = false) | ||||||
|
||||||
Neighborhood search based on the package [CellListMap.jl](https://github.com/m3g/CellListMap.jl). | ||||||
This package provides a similar implementation to the [`GridNeighborhoodSearch`](@ref) | ||||||
with [`FullGridCellList`](@ref), but with better support for periodic boundaries. | ||||||
This is just a wrapper to use CellListMap.jl with the PointNeighbors.jl API. | ||||||
Note that periodic boundaries are not yet supported. | ||||||
|
||||||
# Arguments | ||||||
- `NDIMS`: Number of dimensions. | ||||||
|
||||||
# Keywords | ||||||
- `search_radius = 1.0`: The fixed search radius. The default of `1.0` is useful together | ||||||
with [`copy_neighborhood_search`](@ref). | ||||||
- `points_equal_neighbors = false`: If `true`, a `CellListMap.CellList` is used instead of | ||||||
a `CellListMap.CellListPair`. This requires that `x === y` | ||||||
in [`initialize!`](@ref) and [`update!`](@ref). | ||||||
This option exists only for benchmarking purposes. | ||||||
It makes the main loop awkward because CellListMap.jl | ||||||
only computes pairs with `i < j` and PointNeighbors.jl | ||||||
computes all pairs, so we have to manually use symmetry | ||||||
to add the missing pairs. | ||||||
|
||||||
!!! warning "Experimental implementation" | ||||||
This is an experimental feature and may change in future releases. | ||||||
""" | ||||||
mutable struct CellListMapNeighborhoodSearch{CL, B} | ||||||
cell_list::CL | ||||||
# Note that we need this struct to be mutable to replace the box in `update!` | ||||||
box::B | ||||||
|
||||||
# Add dispatch on `NDIMS` to avoid method overwriting of the function in PointNeighbors.jl | ||||||
function PointNeighbors.CellListMapNeighborhoodSearch(NDIMS::Integer; | ||||||
search_radius = 1.0, | ||||||
points_equal_neighbors = false) | ||||||
# Create a cell list with only one point and resize it later | ||||||
x = zeros(NDIMS, 1) | ||||||
box = CellListMap.Box(CellListMap.limits(x, x), search_radius) | ||||||
|
||||||
if points_equal_neighbors | ||||||
cell_list = CellListMap.CellList(x, box) | ||||||
else | ||||||
cell_list = CellListMap.CellList(x, x, box) | ||||||
end | ||||||
|
||||||
return new{typeof(cell_list), typeof(box)}(cell_list, box) | ||||||
end | ||||||
end | ||||||
|
||||||
function PointNeighbors.search_radius(neighborhood_search::CellListMapNeighborhoodSearch) | ||||||
return neighborhood_search.box.cutoff | ||||||
end | ||||||
|
||||||
function Base.ndims(neighborhood_search::CellListMapNeighborhoodSearch) | ||||||
return length(neighborhood_search.box.cell_size) | ||||||
end | ||||||
|
||||||
function PointNeighbors.initialize!(neighborhood_search::CellListMapNeighborhoodSearch, | ||||||
x::AbstractMatrix, y::AbstractMatrix) | ||||||
PointNeighbors.update!(neighborhood_search, x, y) | ||||||
end | ||||||
|
||||||
# When `x !== y`, a `CellListPair` must be used | ||||||
function PointNeighbors.update!(neighborhood_search::CellListMapNeighborhoodSearch{<:CellListPair}, | ||||||
x::AbstractMatrix, y::AbstractMatrix; | ||||||
points_moving = (true, true)) | ||||||
(; cell_list) = neighborhood_search | ||||||
|
||||||
# Resize box | ||||||
box = CellListMap.Box(CellListMap.limits(x, y), neighborhood_search.box.cutoff) | ||||||
neighborhood_search.box = box | ||||||
|
||||||
# Resize and update cell list | ||||||
CellListMap.UpdateCellList!(x, y, box, cell_list) | ||||||
|
||||||
# Recalculate number of batches for multithreading | ||||||
CellListMap.set_number_of_batches!(cell_list) | ||||||
|
||||||
return neighborhood_search | ||||||
end | ||||||
|
||||||
# When `points_equal_neighbors == true`, a `CellList` is used and `x` must be equal to `y` | ||||||
function PointNeighbors.update!(neighborhood_search::CellListMapNeighborhoodSearch{<:CellList}, | ||||||
x::AbstractMatrix, y::AbstractMatrix; | ||||||
points_moving = (true, true)) | ||||||
(; cell_list) = neighborhood_search | ||||||
|
||||||
@assert x===y "When `points_equal_neighbors == true`, `x` must be equal to `y`" | ||||||
|
||||||
# Resize box | ||||||
box = CellListMap.Box(CellListMap.limits(x), neighborhood_search.box.cutoff) | ||||||
neighborhood_search.box = box | ||||||
|
||||||
# Resize and update cell list | ||||||
CellListMap.UpdateCellList!(x, box, cell_list) | ||||||
|
||||||
# Recalculate number of batches for multithreading | ||||||
CellListMap.set_number_of_batches!(cell_list) | ||||||
|
||||||
# Due to https://github.com/m3g/CellListMap.jl/issues/106, we have to update again | ||||||
CellListMap.UpdateCellList!(x, box, cell_list) | ||||||
|
||||||
return neighborhood_search | ||||||
end | ||||||
|
||||||
# The type annotation is to make Julia specialize on the type of the function. | ||||||
# Otherwise, unspecialized code will cause a lot of allocations | ||||||
# and heavily impact performance. | ||||||
# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing | ||||||
function PointNeighbors.foreach_point_neighbor(f::T, system_coords, neighbor_coords, | ||||||
neighborhood_search::CellListMapNeighborhoodSearch{<:CellListPair}; | ||||||
points = axes(system_coords, 2), | ||||||
parallel = true) where {T} | ||||||
(; cell_list, box) = neighborhood_search | ||||||
|
||||||
# `0` is the returned output, which we don't use. | ||||||
# Note that `parallel !== false` is `true` when `parallel` is a PointNeighbors backend. | ||||||
CellListMap.map_pairwise!(0, box, cell_list, | ||||||
parallel = parallel !== false) do x, y, i, j, d2, output | ||||||
LasNikas marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# Skip all indices not in `points` | ||||||
i in points || return output | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Doesn't this produce a notable overhead which affects the benchmark? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried. It's negligible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just out of interest. Am I missing something? julia> f(i, points) = i in points
f (generic function with 1 method)
julia> points = rand(Int, 1_000_000);
julia> @benchmark f($5, $points)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 221.584 μs … 380.625 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 221.834 μs ┊ GC (median): 0.00%
Time (mean ± σ): 223.171 μs ± 4.460 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
█▅▂▄▁ ▅ ▁ ▁ ▁
█████▇▆██▇██▆██▇▆▆▆▇▅▆▆▅▅▆▅▇▅▄▄▃▅▄▄▅▃▃▄▃▄▂▂▃▃▃▃▂▃▂▂▄▅▇▆▄▃▅▂▃▇ █
222 μs Histogram: log(frequency) by time 243 μs <
Memory estimate: 0 bytes, allocs estimate: 0. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference is that this is about 1-2% of the actual computation when you have a real example where you are doing actual work per particle. |
||||||
|
||||||
pos_diff = x - y | ||||||
distance = sqrt(d2) | ||||||
|
||||||
@inline f(i, j, pos_diff, distance) | ||||||
|
||||||
return output | ||||||
end | ||||||
|
||||||
return nothing | ||||||
end | ||||||
|
||||||
function PointNeighbors.foreach_point_neighbor(f::T, system_coords, neighbor_coords, | ||||||
neighborhood_search::CellListMapNeighborhoodSearch{<:CellList}; | ||||||
points = axes(system_coords, 2), | ||||||
parallel = true) where {T} | ||||||
(; cell_list, box) = neighborhood_search | ||||||
|
||||||
# `0` is the returned output, which we don't use. | ||||||
# Note that `parallel !== false` is `true` when `parallel` is a PointNeighbors backend. | ||||||
CellListMap.map_pairwise!(0, box, cell_list, | ||||||
parallel = parallel !== false) do x, y, i, j, d2, output | ||||||
# Skip all indices not in `points` | ||||||
i in points || return output | ||||||
|
||||||
pos_diff = x - y | ||||||
distance = sqrt(d2) | ||||||
|
||||||
# When `points_equal_neighbors == true`, a `CellList` is used. | ||||||
# With a `CellList`, we only see each pair once and have to use symmetry manually. | ||||||
@inline f(i, j, pos_diff, distance) | ||||||
@inline f(j, i, -pos_diff, distance) | ||||||
|
||||||
return output | ||||||
end | ||||||
|
||||||
# With a `CellList`, only pairs with `i < j` are considered. | ||||||
# We can cover `i > j` with symmetry above, but `i = j` has to be computed separately. | ||||||
PointNeighbors.@threaded system_coords for point in points | ||||||
zero_pos_diff = zero(PointNeighbors.SVector{ndims(neighborhood_search), | ||||||
eltype(system_coords)}) | ||||||
@inline f(point, point, zero_pos_diff, zero(eltype(system_coords))) | ||||||
end | ||||||
|
||||||
return nothing | ||||||
end | ||||||
|
||||||
function PointNeighbors.copy_neighborhood_search(nhs::CellListMapNeighborhoodSearch{<:CellListPair}, | ||||||
search_radius, n_points; | ||||||
eachpoint = 1:n_points) | ||||||
return PointNeighbors.CellListMapNeighborhoodSearch(ndims(nhs); search_radius, | ||||||
points_equal_neighbors = false) | ||||||
end | ||||||
|
||||||
function PointNeighbors.copy_neighborhood_search(nhs::CellListMapNeighborhoodSearch{<:CellList}, | ||||||
search_radius, n_points; | ||||||
eachpoint = 1:n_points) | ||||||
return PointNeighbors.CellListMapNeighborhoodSearch(ndims(nhs); search_radius, | ||||||
points_equal_neighbors = true) | ||||||
end | ||||||
|
||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
[deps] | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
TrixiParticles = "66699cd8-9c01-4e9d-a059-b96c86d16b3a" | ||
|
||
[compat] | ||
BenchmarkTools = "1" | ||
CellListMap = "0.9" | ||
Plots = "1" | ||
Test = "1" | ||
TrixiParticles = "0.2" |
Uh oh!
There was an error while loading. Please reload this page.