|
| 1 | +# FIXME: Remove me |
| 2 | +const Read = In |
| 3 | +const Write = Out |
| 4 | +const ReadWrite = InOut |
| 5 | + |
| 6 | +function load_neighbor_edge(arr, dim, dir, neigh_dist) |
| 7 | + if dir == -1 |
| 8 | + start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) |
| 9 | + stop_idx = CartesianIndex(ntuple(i -> i == dim ? lastindex(arr, i) : lastindex(arr, i), ndims(arr))) |
| 10 | + elseif dir == 1 |
| 11 | + start_idx = CartesianIndex(ntuple(i -> i == dim ? firstindex(arr, i) : firstindex(arr, i), ndims(arr))) |
| 12 | + stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr))) |
| 13 | + end |
| 14 | + # FIXME: Don't collect |
| 15 | + return move(thunk_processor(), collect(@view arr[start_idx:stop_idx])) |
| 16 | +end |
| 17 | +function load_neighbor_corner(arr, corner_side, neigh_dist) |
| 18 | + start_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) |
| 19 | + stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(arr, i) : (firstindex(arr, i) + neigh_dist - 1), ndims(arr))) |
| 20 | + return move(thunk_processor(), collect(@view arr[start_idx:stop_idx])) |
| 21 | +end |
| 22 | +function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary) |
| 23 | + @assert neigh_dist isa Integer && neigh_dist > 0 "Neighborhood distance must be an Integer greater than 0" |
| 24 | + |
| 25 | + # FIXME: Depends on neigh_dist and chunk size |
| 26 | + chunk_dist = 1 |
| 27 | + # Get the center |
| 28 | + accesses = Any[chunks[idx]] |
| 29 | + |
| 30 | + # Get the edges |
| 31 | + for dim in 1:ndims(chunks) |
| 32 | + for dir in (-1, +1) |
| 33 | + new_idx = idx + CartesianIndex(ntuple(i -> i == dim ? dir*chunk_dist : 0, ndims(chunks))) |
| 34 | + if is_past_boundary(size(chunks), new_idx) |
| 35 | + if boundary_has_transition(boundary) |
| 36 | + new_idx = boundary_transition(boundary, new_idx, size(chunks)) |
| 37 | + else |
| 38 | + new_idx = idx |
| 39 | + end |
| 40 | + chunk = chunks[new_idx] |
| 41 | + push!(accesses, Dagger.@spawn load_boundary_edge(boundary, chunk, dim, dir, neigh_dist)) |
| 42 | + else |
| 43 | + chunk = chunks[new_idx] |
| 44 | + push!(accesses, Dagger.@spawn load_neighbor_edge(chunk, dim, dir, neigh_dist)) |
| 45 | + end |
| 46 | + end |
| 47 | + end |
| 48 | + |
| 49 | + # Get the corners |
| 50 | + for corner_num in 1:(2^ndims(chunks)) |
| 51 | + corner_side = CartesianIndex(reverse(ntuple(ndims(chunks)) do i |
| 52 | + ((corner_num-1) >> (((ndims(chunks) - i) + 1) - 1)) & 1 |
| 53 | + end)) |
| 54 | + corner_new_idx = CartesianIndex(ntuple(ndims(chunks)) do i |
| 55 | + corner_shift = iszero(corner_side[i]) ? -1 : 1 |
| 56 | + return idx[i] + corner_shift |
| 57 | + end) |
| 58 | + if is_past_boundary(size(chunks), corner_new_idx) |
| 59 | + if boundary_has_transition(boundary) |
| 60 | + corner_new_idx = boundary_transition(boundary, corner_new_idx, size(chunks)) |
| 61 | + else |
| 62 | + corner_new_idx = idx |
| 63 | + end |
| 64 | + chunk = chunks[corner_new_idx] |
| 65 | + push!(accesses, Dagger.@spawn load_boundary_corner(boundary, chunk, corner_side, neigh_dist)) |
| 66 | + else |
| 67 | + chunk = chunks[corner_new_idx] |
| 68 | + push!(accesses, Dagger.@spawn load_neighbor_corner(chunk, corner_side, neigh_dist)) |
| 69 | + end |
| 70 | + end |
| 71 | + |
| 72 | + @assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))" |
| 73 | + return accesses |
| 74 | +end |
| 75 | +function build_halo(neigh_dist, boundary, center, all_neighbors...) |
| 76 | + N = ndims(center) |
| 77 | + edges = all_neighbors[1:(2*N)] |
| 78 | + corners = all_neighbors[((2^N)+1):end] |
| 79 | + @assert length(edges) == 2*N && length(corners) == 2^N "Halo mismatch: edges=$(length(edges)) corners=$(length(corners))" |
| 80 | + return HaloArray(center, (edges...,), (corners...,), ntuple(_->neigh_dist, N)) |
| 81 | +end |
| 82 | +function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N} |
| 83 | + @assert all(arr.halo_width .== arr.halo_width[1]) |
| 84 | + neigh_dist = arr.halo_width[1] |
| 85 | + start_idx = idx - CartesianIndex(ntuple(_->neigh_dist, ndims(arr))) |
| 86 | + stop_idx = idx + CartesianIndex(ntuple(_->neigh_dist, ndims(arr))) |
| 87 | + return @view arr[start_idx:stop_idx] |
| 88 | +end |
| 89 | +function inner_stencil!(f, output, read_vars) |
| 90 | + processor = thunk_processor() |
| 91 | + inner_stencil_proc!(processor, f, output, read_vars) |
| 92 | +end |
| 93 | +# Non-KA (for CPUs) |
| 94 | +function inner_stencil_proc!(::ThreadProc, f, output, read_vars) |
| 95 | + for idx in CartesianIndices(output) |
| 96 | + f(idx, output, read_vars) |
| 97 | + end |
| 98 | + return |
| 99 | +end |
| 100 | + |
| 101 | +is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size))) |
| 102 | + |
| 103 | +struct Wrap end |
| 104 | +boundary_has_transition(::Wrap) = true |
| 105 | +boundary_transition(::Wrap, idx, size) = |
| 106 | + CartesianIndex(ntuple(i -> mod1(idx[i], size[i]), length(size))) |
| 107 | +load_boundary_edge(::Wrap, arr, dim, dir, neigh_dist) = load_neighbor_edge(arr, dim, dir, neigh_dist) |
| 108 | +load_boundary_corner(::Wrap, arr, corner_side, neigh_dist) = load_neighbor_corner(arr, corner_side, neigh_dist) |
| 109 | + |
| 110 | +struct Pad{T} |
| 111 | + padval::T |
| 112 | +end |
| 113 | +boundary_has_transition(::Pad) = false |
| 114 | +function load_boundary_edge(pad::Pad, arr, dim, dir, neigh_dist) |
| 115 | + if dir == -1 |
| 116 | + start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) |
| 117 | + stop_idx = CartesianIndex(ntuple(i -> i == dim ? lastindex(arr, i) : lastindex(arr, i), ndims(arr))) |
| 118 | + elseif dir == 1 |
| 119 | + start_idx = CartesianIndex(ntuple(i -> i == dim ? firstindex(arr, i) : firstindex(arr, i), ndims(arr))) |
| 120 | + stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr))) |
| 121 | + end |
| 122 | + edge_size = ntuple(i -> length(start_idx[i]:stop_idx[i]), ndims(arr)) |
| 123 | + # FIXME: return Fill(pad.padval, edge_size) |
| 124 | + return move(thunk_processor(), fill(pad.padval, edge_size)) |
| 125 | +end |
| 126 | +function load_boundary_corner(pad::Pad, arr, corner_side, neigh_dist) |
| 127 | + start_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) |
| 128 | + stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(arr, i) : (firstindex(arr, i) + neigh_dist - 1), ndims(arr))) |
| 129 | + corner_size = ntuple(i -> length(start_idx[i]:stop_idx[i]), ndims(arr)) |
| 130 | + # FIXME: return Fill(pad.padval, corner_size) |
| 131 | + return move(thunk_processor(), fill(pad.padval, corner_size)) |
| 132 | +end |
| 133 | + |
| 134 | +""" |
| 135 | + @stencil begin body end |
| 136 | +
|
| 137 | +Allows the specification of stencil operations within a `spawn_datadeps` |
| 138 | +region. The `idx` variable is used to iterate over `range`, which must be a |
| 139 | +`DArray`. An example usage may look like: |
| 140 | +
|
| 141 | +```julia |
| 142 | +import Dagger: @stencil, Wrap |
| 143 | +
|
| 144 | +A = zeros(Blocks(3, 3), Int, 9, 9) |
| 145 | +A[5, 5] = 1 |
| 146 | +B = zeros(Blocks(3, 3), Int, 9, 9) |
| 147 | +Dagger.spawn_datadeps() do |
| 148 | + @stencil begin |
| 149 | + # Sum values of all neighbors with self |
| 150 | + A[idx] = sum(@neighbors(A[idx], 1, Wrap())) |
| 151 | + # Decrement all values by 1 |
| 152 | + A[idx] -= 1 |
| 153 | + # Copy A to B |
| 154 | + B[idx] = A[idx] |
| 155 | + end |
| 156 | +end |
| 157 | +``` |
| 158 | +
|
| 159 | +Each expression within an `@stencil` region that performs an in-place indexing |
| 160 | +expression like `A[idx] = ...` is transformed into a set of tasks that operate |
| 161 | +on each chunk of `A` or any other arrays specified as `A[idx]`, and within each |
| 162 | +task, elements of that chunk of `A` can be accessed. Elements of multiple |
| 163 | +`DArray`s can be accessed, such as `B[idx]`, so long as `B` has the same size, |
| 164 | +shape, and chunk layout as `A`. |
| 165 | +
|
| 166 | +Additionally, the `@neighbors` macro can be used to access a neighborhood of |
| 167 | +values around `A[idx]`, at a configurable distance (in this case, 1 element |
| 168 | +distance) and with various kinds of boundary conditions (in this case, `Wrap()` |
| 169 | +specifies wrapping behavior on the boundaries). Neighborhoods are computed with |
| 170 | +respect to neighboring chunks as well - if a neighborhood would overflow from |
| 171 | +the current chunk into one or more neighboring chunks, values from those |
| 172 | +neighboring chunks will be included in the neighborhood. |
| 173 | +
|
| 174 | +Note that, while `@stencil` may look like a `for` loop, it does not follow the |
| 175 | +same semantics; in particular, an expression within `@stencil` occurs "all at |
| 176 | +once" (across all indices) before the next expression occurs. This means that |
| 177 | +`A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` will write the sum of |
| 178 | +neighbors for all `idx` values into `A[idx]` before `A[idx] -= 1` decrements |
| 179 | +the values `A` by 1, and that occurs before any of the values are copied to `B` |
| 180 | +in `B[idx] = A[idx]`. Of course, pipelining and other optimizations may still |
| 181 | +occur, so long as they respect the sequential nature of `@stencil` (just like |
| 182 | +with other operations in `spawn_datadeps`). |
| 183 | +""" |
| 184 | +macro stencil(orig_ex) |
| 185 | + @assert Meta.isexpr(orig_ex, :block) "Invalid stencil block: $orig_ex" |
| 186 | + |
| 187 | + # Collect access pattern information |
| 188 | + inners = [] |
| 189 | + all_accessed_vars = Set{Symbol}() |
| 190 | + for inner_ex in orig_ex.args |
| 191 | + inner_ex isa LineNumberNode && continue |
| 192 | + @assert @capture(inner_ex, write_ex_ = read_ex_) "Invalid update expression: $inner_ex" |
| 193 | + @assert @capture(write_ex, write_var_[write_idx_]) "Update expression requires a write: $write_ex" |
| 194 | + accessed_vars = Set{Symbol}() |
| 195 | + read_vars = Set{Symbol}() |
| 196 | + neighborhoods = Dict{Symbol, Tuple{Any, Any}}() |
| 197 | + push!(accessed_vars, write_var) |
| 198 | + prewalk(read_ex) do read_inner_ex |
| 199 | + if @capture(read_inner_ex, read_var_[read_idx_]) && read_idx == write_idx |
| 200 | + push!(accessed_vars, read_var) |
| 201 | + push!(read_vars, read_var) |
| 202 | + elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) |
| 203 | + @assert read_idx == write_idx "Neighborhood access must be at the same index as the write: $read_inner_ex" |
| 204 | + push!(accessed_vars, read_var) |
| 205 | + push!(read_vars, read_var) |
| 206 | + neighborhoods[read_var] = (neigh_dist, boundary) |
| 207 | + end |
| 208 | + return read_inner_ex |
| 209 | + end |
| 210 | + union!(all_accessed_vars, accessed_vars) |
| 211 | + push!(inners, (;inner_ex, accessed_vars, write_var, write_idx, read_ex, read_vars, neighborhoods)) |
| 212 | + end |
| 213 | + |
| 214 | + # Codegen update functions |
| 215 | + final_ex = Expr(:block) |
| 216 | + @gensym chunk_idx |
| 217 | + for (;inner_ex, accessed_vars, write_var, write_idx, read_ex, read_vars, neighborhoods) in inners |
| 218 | + # Generate a variable for chunk access |
| 219 | + @gensym chunk_idx |
| 220 | + |
| 221 | + # Generate function with transformed body |
| 222 | + @gensym inner_vars inner_index_var |
| 223 | + new_inner_ex_body = prewalk(inner_ex) do old_inner_ex |
| 224 | + if @capture(old_inner_ex, read_var_[read_idx_]) && read_idx == write_idx |
| 225 | + # Direct access |
| 226 | + if read_var == write_var |
| 227 | + return :($write_var[$inner_index_var]) |
| 228 | + else |
| 229 | + return :($inner_vars.$read_var[$inner_index_var]) |
| 230 | + end |
| 231 | + elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) |
| 232 | + # Neighborhood access |
| 233 | + return :($load_neighborhood($inner_vars.$read_var, $inner_index_var)) |
| 234 | + end |
| 235 | + return old_inner_ex |
| 236 | + end |
| 237 | + new_inner_f = :(($inner_index_var, $write_var, $inner_vars)->$new_inner_ex_body) |
| 238 | + new_inner_ex = quote |
| 239 | + $inner_vars = (;$(read_vars...)) |
| 240 | + $inner_stencil!($new_inner_f, $write_var, $inner_vars) |
| 241 | + end |
| 242 | + inner_fn = Expr(:->, Expr(:tuple, Expr(:parameters, write_var, read_vars...)), new_inner_ex) |
| 243 | + |
| 244 | + # Generate @spawn call with appropriate vars and deps |
| 245 | + deps_ex = Any[] |
| 246 | + if write_var in read_vars |
| 247 | + push!(deps_ex, Expr(:kw, write_var, :($ReadWrite($chunks($write_var)[$chunk_idx])))) |
| 248 | + else |
| 249 | + push!(deps_ex, Expr(:kw, write_var, :($Write($chunks($write_var)[$chunk_idx])))) |
| 250 | + end |
| 251 | + neighbor_copy_all_ex = Expr(:block) |
| 252 | + for read_var in read_vars |
| 253 | + if read_var in keys(neighborhoods) |
| 254 | + # Generate a neighborhood copy operation |
| 255 | + neigh_dist, boundary = neighborhoods[read_var] |
| 256 | + deps_inner_ex = Expr(:block) |
| 257 | + @gensym neighbor_copy_var |
| 258 | + push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn name="stencil_build_halo" $build_halo($neigh_dist, $boundary, map($Read, $select_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary))...))) |
| 259 | + push!(deps_ex, Expr(:kw, read_var, :($Read($neighbor_copy_var)))) |
| 260 | + else |
| 261 | + push!(deps_ex, Expr(:kw, read_var, :($Read($chunks($read_var)[$chunk_idx])))) |
| 262 | + end |
| 263 | + end |
| 264 | + spawn_ex = :(Dagger.@spawn name="stencil_inner_fn" $inner_fn(;$(deps_ex...))) |
| 265 | + |
| 266 | + # Generate loop |
| 267 | + push!(final_ex.args, quote |
| 268 | + for $chunk_idx in $CartesianIndices($chunks($write_var)) |
| 269 | + $neighbor_copy_all_ex |
| 270 | + $spawn_ex |
| 271 | + end |
| 272 | + end) |
| 273 | + end |
| 274 | + |
| 275 | + |
| 276 | + return esc(final_ex) |
| 277 | +end |
0 commit comments