Skip to content

Commit 21c2bca

Browse files
committed
datadeps: Add at-stencil helper
1 parent 348b2a5 commit 21c2bca

File tree

5 files changed

+424
-2
lines changed

5 files changed

+424
-2
lines changed

docs/make.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ makedocs(;
2626
"Scopes" => "scopes.md",
2727
"Processors" => "processors.md",
2828
"Task Queues" => "task-queues.md",
29-
"Datadeps" => "datadeps.md",
29+
"Datadeps" => [
30+
"Basics" => "datadeps.md",
31+
"Stencils" => "stencils.md",
32+
],
3033
"Option Propagation" => "propagation.md",
3134
"Logging and Visualization" => [
3235
"Logging: Basics" => "logging.md",

docs/src/stencils.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Stencil Operations
2+
3+
4+
5+
```julia
6+
# Game-of-Life style example built on `@stencil`.
# Plots.jl provides `@animate`, `heatmap`, and `mp4` used below.
using Dagger
using Plots
import Dagger: @stencil, Wrap

N = 27
nt = 3
niters = 10  # number of generations to animate
tiles = zeros(Blocks(N, N), Bool, N*nt, N*nt)
outputs = zeros(Blocks(N, N), Bool, N*nt, N*nt)

# Create fun initial state
tiles[13, 14] = 1
tiles[14, 14] = 1
tiles[15, 14] = 1
tiles[15, 15] = 1
tiles[14, 16] = 1
@view(tiles[(2N+1):3N, (2N+1):3N]) .= rand(Bool, N, N)

anim = @animate for _ in 1:niters
    Dagger.spawn_datadeps() do
        @stencil begin
            # Compute the next generation into `outputs`
            outputs[idx] = begin
                nhood = @neighbors(tiles[idx], 1, Wrap())
                # The neighborhood includes the cell itself, so subtract it
                neighs = sum(nhood) - tiles[idx]
                if tiles[idx] && neighs < 2
                    0
                elseif tiles[idx] && neighs > 3
                    0
                elseif !tiles[idx] && neighs == 3
                    1
                else
                    tiles[idx]
                end
            end
            # Copy the new generation back for the next iteration
            tiles[idx] = outputs[idx]
        end
    end
    heatmap(Int.(collect(outputs)))
end
path = mp4(anim; fps=5, show_msg=true).filename
43+
```

src/Dagger.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ include("utils/dagdebug.jl")
5050
include("utils/locked-object.jl")
5151
include("utils/tasks.jl")
5252

53-
import MacroTools: @capture
53+
import MacroTools: @capture, prewalk
54+
5455
include("options.jl")
5556
include("processor.jl")
5657
include("threadproc.jl")
@@ -76,6 +77,8 @@ include("sch/Sch.jl"); using .Sch
7677

7778
# Data dependency task queue
7879
include("datadeps.jl")
80+
include("utils/haloarray.jl")
81+
include("stencil.jl")
7982

8083
# Streaming
8184
include("stream.jl")

src/stencil.jl

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
# FIXME: Remove me
# Short-lived aliases so the stencil code can spell its dependency wrappers as
# `Read`/`Write`/`ReadWrite` instead of `In`/`Out`/`InOut`.
const Read, Write, ReadWrite = In, Out, InOut
5+
6+
"""
    load_neighbor_edge(arr, dim, dir, neigh_dist)

Copy the slab of `arr` that faces a neighbor along dimension `dim`.

For `dir == -1` the last `neigh_dist` indices of `arr` along `dim` are taken;
for `dir == 1` the first `neigh_dist` indices are taken. All other dimensions
are taken in full. The slab is collected into a new array and handed to
`move(thunk_processor(), ...)`.

Throws an `ArgumentError` if `dir` is neither `-1` nor `1`.
"""
function load_neighbor_edge(arr, dim, dir, neigh_dist)
    # Reject bad directions up front; otherwise neither branch below would
    # assign the bounds and the failure would be an opaque UndefVarError.
    (dir == -1 || dir == 1) ||
        throw(ArgumentError("dir must be -1 or +1, got $dir"))

    nd = ndims(arr)
    if dir == -1
        start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), nd))
        stop_idx = CartesianIndex(ntuple(i -> lastindex(arr, i), nd))
    else
        start_idx = CartesianIndex(ntuple(i -> firstindex(arr, i), nd))
        stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), nd))
    end
    # FIXME: Don't collect
    return move(thunk_processor(), collect(@view arr[start_idx:stop_idx]))
end
17+
"""
    load_neighbor_corner(arr, corner_side, neigh_dist)

Copy a corner block of `arr` spanning `neigh_dist` indices in every dimension.

Along each dimension `d`, `corner_side[d] == 0` selects the last `neigh_dist`
indices of `arr`, and any other value selects the first `neigh_dist` indices.
The block is collected and handed to `move(thunk_processor(), ...)`.
"""
function load_neighbor_corner(arr, corner_side, neigh_dist)
    # Per-dimension (low, high) bounds of the corner block
    bounds = ntuple(ndims(arr)) do d
        if corner_side[d] == 0
            hi = lastindex(arr, d)
            return (hi - neigh_dist + 1, hi)
        else
            lo = firstindex(arr, d)
            return (lo, lo + neigh_dist - 1)
        end
    end
    lower = CartesianIndex(map(first, bounds))
    upper = CartesianIndex(map(last, bounds))
    return move(thunk_processor(), collect(@view arr[lower:upper]))
end
22+
"""
    select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)

Gather everything needed to build the halo around the chunk at `chunks[idx]`.

Returns a `Vector{Any}` laid out as:
1. the center chunk `chunks[idx]`,
2. one spawned task per edge, ordered by dimension and then by direction
   (`-1` before `+1`),
3. one spawned task per corner, `2^ndims(chunks)` in total.

When a neighbor index falls outside the chunk grid, the chunk is chosen by
`boundary`: if `boundary_has_transition(boundary)` is true the index is
remapped with `boundary_transition`, otherwise the center chunk itself is
used. In both cases the data is loaded with `load_boundary_edge` /
`load_boundary_corner` instead of the `load_neighbor_*` functions.
"""
function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
    @assert neigh_dist isa Integer && neigh_dist > 0 "Neighborhood distance must be an Integer greater than 0"

    nd = ndims(chunks)
    grid = size(chunks)

    # FIXME: Depends on neigh_dist and chunk size
    chunk_dist = 1

    # Chunk index to use for a neighbor that lies outside the chunk grid
    remap(outside_idx) = boundary_has_transition(boundary) ?
        boundary_transition(boundary, outside_idx, grid) : idx

    # The center chunk always comes first
    accesses = Any[chunks[idx]]

    # Edges: one per (dimension, direction) pair
    for dim in 1:nd, dir in (-1, +1)
        offset = CartesianIndex(ntuple(d -> d == dim ? dir*chunk_dist : 0, nd))
        edge_idx = idx + offset
        if is_past_boundary(grid, edge_idx)
            chunk = chunks[remap(edge_idx)]
            push!(accesses, Dagger.@spawn load_boundary_edge(boundary, chunk, dim, dir, neigh_dist))
        else
            chunk = chunks[edge_idx]
            push!(accesses, Dagger.@spawn load_neighbor_edge(chunk, dim, dir, neigh_dist))
        end
    end

    # Corners: bit (d-1) of `corner_num-1` picks the side along dimension d
    # (0 => step to the lower neighbor, 1 => step to the upper neighbor)
    for corner_num in 1:(2^nd)
        corner_side = CartesianIndex(ntuple(d -> ((corner_num - 1) >> (d - 1)) & 1, nd))
        corner_idx = CartesianIndex(ntuple(d -> idx[d] + (iszero(corner_side[d]) ? -1 : 1), nd))
        if is_past_boundary(grid, corner_idx)
            chunk = chunks[remap(corner_idx)]
            push!(accesses, Dagger.@spawn load_boundary_corner(boundary, chunk, corner_side, neigh_dist))
        else
            chunk = chunks[corner_idx]
            push!(accesses, Dagger.@spawn load_neighbor_corner(chunk, corner_side, neigh_dist))
        end
    end

    @assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))"
    return accesses
end
75+
"""
    build_halo(neigh_dist, boundary, center, all_neighbors...)

Assemble a `HaloArray` around `center` from its neighbor pieces.

`all_neighbors` must hold the `2*ndims(center)` edge pieces first, followed by
the `2^ndims(center)` corner pieces. The halo width is `neigh_dist` in every
dimension. `boundary` is accepted but not used here.
"""
function build_halo(neigh_dist, boundary, center, all_neighbors...)
    N = ndims(center)
    n_edges = 2*N
    edges = all_neighbors[1:n_edges]
    # Corners start right after the edges. Slicing from `2^N + 1` only matches
    # `2N + 1` for N = 1 and N = 2, and drops corners for N >= 3.
    corners = all_neighbors[(n_edges+1):end]
    @assert length(edges) == 2*N && length(corners) == 2^N "Halo mismatch: edges=$(length(edges)) corners=$(length(corners))"
    return HaloArray(center, (edges...,), (corners...,), ntuple(_->neigh_dist, N))
end
82+
"""
    load_neighborhood(arr::HaloArray, idx)

Return a view of `arr` covering the window of `idx ± halo_width` in every
dimension. The halo width must be the same in all dimensions.
"""
function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N}
    @assert all(arr.halo_width .== arr.halo_width[1])
    # Same reach in every dimension, taken from the first halo width
    reach = CartesianIndex(ntuple(_ -> arr.halo_width[1], ndims(arr)))
    return view(arr, (idx - reach):(idx + reach))
end
89+
"""
    inner_stencil!(f, output, read_vars)

Run the per-element stencil function `f` over `output`, dispatching on the
processor returned by `thunk_processor()`.
"""
function inner_stencil!(f, output, read_vars)
    return inner_stencil_proc!(thunk_processor(), f, output, read_vars)
end
93+
# Non-KA (for CPUs)
# Calls `f(idx, output, read_vars)` once for every Cartesian index of `output`.
function inner_stencil_proc!(::ThreadProc, f, output, read_vars)
    foreach(CartesianIndices(output)) do idx
        f(idx, output, read_vars)
    end
    return nothing
end
100+
101+
# Return `true` if `idx` lies outside `1:size[d]` along any dimension `d`.
function is_past_boundary(size, idx)
    outside = ntuple(length(size)) do d
        return idx[d] < 1 || idx[d] > size[d]
    end
    return any(outside)
end
102+
103+
# Boundary condition that wraps out-of-range chunk indices around the grid.
struct Wrap end

# Out-of-range chunk indices are remapped to another chunk.
boundary_has_transition(::Wrap) = true

# Wrap each component of `idx` into `1:size[d]`.
function boundary_transition(::Wrap, idx, size)
    wrapped = ntuple(length(size)) do d
        return mod1(idx[d], size[d])
    end
    return CartesianIndex(wrapped)
end

# With wrapping, boundary pieces are loaded exactly like interior neighbor pieces.
function load_boundary_edge(::Wrap, arr, dim, dir, neigh_dist)
    return load_neighbor_edge(arr, dim, dir, neigh_dist)
end
function load_boundary_corner(::Wrap, arr, corner_side, neigh_dist)
    return load_neighbor_corner(arr, corner_side, neigh_dist)
end
109+
110+
# Boundary condition that fills out-of-range neighbor pieces with `padval`.
struct Pad{T}
    padval::T
end

# Out-of-range chunk indices are not remapped to another chunk.
boundary_has_transition(::Pad) = false

# Build an edge piece filled with `pad.padval`. It has extent `neigh_dist`
# along `dim` and the extent of `arr` along every other dimension.
# Throws an `ArgumentError` if `dir` is neither `-1` nor `1`.
function load_boundary_edge(pad::Pad, arr, dim, dir, neigh_dist)
    # Both directions give the same shape; `dir` is only validated so that a
    # bad value fails clearly rather than with an UndefVarError.
    (dir == -1 || dir == 1) ||
        throw(ArgumentError("dir must be -1 or +1, got $dir"))
    edge_size = ntuple(i -> i == dim ? Int(neigh_dist) : size(arr, i), ndims(arr))
    # FIXME: return Fill(pad.padval, edge_size)
    return move(thunk_processor(), fill(pad.padval, edge_size))
end

# Build a corner piece filled with `pad.padval`, with extent `neigh_dist` along
# every dimension of `arr`. `corner_side` does not affect the shape.
function load_boundary_corner(pad::Pad, arr, corner_side, neigh_dist)
    corner_size = ntuple(_ -> Int(neigh_dist), ndims(arr))
    # FIXME: return Fill(pad.padval, corner_size)
    return move(thunk_processor(), fill(pad.padval, corner_size))
end
133+
134+
"""
    @stencil begin body end

Allows the specification of stencil operations within a `spawn_datadeps`
region. Each statement in `body` must be an indexed assignment such as
`A[idx] = ...`; the index expression (`idx` in the examples) stands for every
index of the written array, which must be a `DArray`. An example usage may
look like:

```julia
import Dagger: @stencil, Wrap

A = zeros(Blocks(3, 3), Int, 9, 9)
A[5, 5] = 1
B = zeros(Blocks(3, 3), Int, 9, 9)
Dagger.spawn_datadeps() do
    @stencil begin
        # Sum values of all neighbors with self
        A[idx] = sum(@neighbors(A[idx], 1, Wrap()))
        # Decrement all values by 1
        A[idx] -= 1
        # Copy A to B
        B[idx] = A[idx]
    end
end
```

Each expression within an `@stencil` region that performs an in-place indexing
expression like `A[idx] = ...` is transformed into a set of tasks that operate
on each chunk of `A` or any other arrays specified as `A[idx]`, and within each
task, elements of that chunk of `A` can be accessed. Elements of multiple
`DArray`s can be accessed, such as `B[idx]`, so long as `B` has the same size,
shape, and chunk layout as `A`.

Additionally, the `@neighbors` macro can be used to access a neighborhood of
values around `A[idx]`, at a configurable distance (in this case, 1 element
distance) and with various kinds of boundary conditions (in this case, `Wrap()`
specifies wrapping behavior on the boundaries). Neighborhoods are computed with
respect to neighboring chunks as well - if a neighborhood would overflow from
the current chunk into one or more neighboring chunks, values from those
neighboring chunks will be included in the neighborhood.

Note that, while `@stencil` may look like a `for` loop, it does not follow the
same semantics; in particular, an expression within `@stencil` occurs "all at
once" (across all indices) before the next expression occurs. This means that
`A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` will write the sum of
neighbors for all `idx` values into `A[idx]` before `A[idx] -= 1` decrements
the values `A` by 1, and that occurs before any of the values are copied to `B`
in `B[idx] = A[idx]`. Of course, pipelining and other optimizations may still
occur, so long as they respect the sequential nature of `@stencil` (just like
with other operations in `spawn_datadeps`).
"""
macro stencil(orig_ex)
    @assert Meta.isexpr(orig_ex, :block) "Invalid stencil block: $orig_ex"

    # Collect access pattern information: for every statement, the written
    # array, the index expression, the arrays read at that same index, and any
    # `@neighbors` accesses with their distance and boundary.
    inners = []
    for inner_ex in orig_ex.args
        inner_ex isa LineNumberNode && continue
        # NOTE(review): compound updates such as `A[idx] -= 1` (used in the
        # docstring example) have head `:-=`; presumably they do not match this
        # `=` pattern - confirm and either support them or fix the example.
        @assert @capture(inner_ex, write_ex_ = read_ex_) "Invalid update expression: $inner_ex"
        @assert @capture(write_ex, write_var_[write_idx_]) "Update expression requires a write: $write_ex"
        read_vars = Set{Symbol}()
        neighborhoods = Dict{Symbol, Tuple{Any, Any}}()
        prewalk(read_ex) do read_inner_ex
            if @capture(read_inner_ex, read_var_[read_idx_]) && read_idx == write_idx
                push!(read_vars, read_var)
            elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_))
                @assert read_idx == write_idx "Neighborhood access must be at the same index as the write: $read_inner_ex"
                push!(read_vars, read_var)
                neighborhoods[read_var] = (neigh_dist, boundary)
            end
            return read_inner_ex
        end
        push!(inners, (;inner_ex, write_var, write_idx, read_vars, neighborhoods))
    end

    # Codegen update functions
    final_ex = Expr(:block)
    for (;inner_ex, write_var, write_idx, read_vars, neighborhoods) in inners
        # Generate a variable for chunk access
        @gensym chunk_idx

        # Generate function with transformed body: `X[idx]` becomes an access
        # at the per-element index, and `@neighbors(X[idx], ...)` becomes a
        # `load_neighborhood` call on the halo-wrapped `X`.
        @gensym inner_vars inner_index_var
        new_inner_ex_body = prewalk(inner_ex) do old_inner_ex
            if @capture(old_inner_ex, read_var_[read_idx_]) && read_idx == write_idx
                # Direct access
                if read_var == write_var
                    return :($write_var[$inner_index_var])
                else
                    return :($inner_vars.$read_var[$inner_index_var])
                end
            elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_))
                # Neighborhood access
                return :($load_neighborhood($inner_vars.$read_var, $inner_index_var))
            end
            return old_inner_ex
        end
        new_inner_f = :(($inner_index_var, $write_var, $inner_vars)->$new_inner_ex_body)
        new_inner_ex = quote
            $inner_vars = (;$(read_vars...))
            $inner_stencil!($new_inner_f, $write_var, $inner_vars)
        end
        # NOTE(review): when `write_var` is also in `read_vars` (e.g.
        # `A[idx] = A[idx] + 1`), it is listed twice in this keyword-parameter
        # tuple and receives two keyword arguments in `deps_ex` below - confirm
        # that this lowers and dispatches as intended.
        inner_fn = Expr(:->, Expr(:tuple, Expr(:parameters, write_var, read_vars...)), new_inner_ex)

        # Generate @spawn call with appropriate vars and deps
        deps_ex = Any[]
        if write_var in read_vars
            push!(deps_ex, Expr(:kw, write_var, :($ReadWrite($chunks($write_var)[$chunk_idx]))))
        else
            push!(deps_ex, Expr(:kw, write_var, :($Write($chunks($write_var)[$chunk_idx]))))
        end
        neighbor_copy_all_ex = Expr(:block)
        for read_var in read_vars
            if read_var in keys(neighborhoods)
                # Generate a neighborhood copy operation
                neigh_dist, boundary = neighborhoods[read_var]
                @gensym neighbor_copy_var
                push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn name="stencil_build_halo" $build_halo(
                    $neigh_dist, $boundary,
                    map($Read, $select_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary))...)))
                push!(deps_ex, Expr(:kw, read_var, :($Read($neighbor_copy_var))))
            else
                push!(deps_ex, Expr(:kw, read_var, :($Read($chunks($read_var)[$chunk_idx]))))
            end
        end
        spawn_ex = :(Dagger.@spawn name="stencil_inner_fn" $inner_fn(;$(deps_ex...)))

        # Generate loop: one halo build (if any) and one inner task per chunk
        push!(final_ex.args, quote
            for $chunk_idx in $CartesianIndices($chunks($write_var))
                $neighbor_copy_all_ex
                $spawn_ex
            end
        end)
    end

    return esc(final_ex)
end

0 commit comments

Comments
 (0)