diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 1a734850b..e3e90a124 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -136,7 +136,9 @@ aliasing(x::Transpose) = aliasing(parent(x)) aliasing(x::Adjoint) = aliasing(parent(x)) struct StridedAliasing{T,N,S} <: AbstractAliasing + base_ptr::RemotePtr{Cvoid,S} ptr::RemotePtr{Cvoid,S} + base_inds::NTuple{N,UnitRange{Int}} lengths::NTuple{N,Int} strides::NTuple{N,Int} end @@ -161,10 +163,12 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T}) where T +function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(x)), + return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + RemotePtr{Cvoid}(pointer(x)), + parentindices(x), size(x), strides(parent(x))) else # FIXME: Also ContiguousAliasing of container @@ -172,21 +176,21 @@ function aliasing(x::SubArray{T}) where T return UnknownAliasing() end end -#= TODO: Fix and enable strided aliasing optimization function will_alias(x::StridedAliasing{T,N,S}, y::StridedAliasing{T,N,S}) where {T,N,S} - # TODO: Upgrade Contiguous/StridedAlising to same number of dims + if x.base_ptr != y.base_ptr + # FIXME: Conservatively incorrect via `unsafe_wrap` and friends + return false + end + for dim in 1:N - # FIXME: Adjust ptrs to common base - x_span = MemorySpan{S}(x.ptr, sizeof(T)*x.strides[dim]) - y_span = MemorySpan{S}(y.ptr, sizeof(T)*y.strides[dim]) - @show dim x_span y_span - if !will_alias(x_span, y_span) + if ((x.base_inds[dim].stop) < (y.base_inds[dim].start) || (y.base_inds[dim].stop) < (x.base_inds[dim].start)) return false end end + return true end -=# +# FIXME: Upgrade Contiguous/StridedAlising to same number of dims struct TriangularAliasing{T,S} <: AbstractAliasing ptr::RemotePtr{Cvoid,S} diff --git a/test/datadeps.jl b/test/datadeps.jl index e4b4c811f..271e2c667 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -337,6 +337,59 @@ function test_datadeps(;args_chunks::Bool, test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + + # Additional aliasing tests + views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) + + A = wrap_chunk_thunk(identity, B) + + A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) + A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) + B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) + B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) + + A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) + A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) + B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) + B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) + + A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + + @test views_overlap(A_r1, A_r1) + @test views_overlap(B_r1, B_r1) + @test views_overlap(A_c1, A_c1) + @test views_overlap(B_c1, B_c1) + + @test views_overlap(A_r1, B_r1) + @test views_overlap(A_r2, B_r2) + @test views_overlap(A_c1, B_c1) + @test views_overlap(A_c2, B_c2) + + @test !views_overlap(A_r1, A_r2) + @test !views_overlap(B_r1, B_r2) + @test !views_overlap(A_c1, A_c2) + @test !views_overlap(B_c1, B_c2) + + @test views_overlap(A_r1, A_c1) + @test views_overlap(A_r1, B_c1) + @test views_overlap(A_r2, A_c2) + @test views_overlap(A_r2, B_c2) + + for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) + @test !views_overlap(A_r1, mid) + @test !views_overlap(B_r1, mid) + @test !views_overlap(A_c1, mid) + @test !views_overlap(B_c1, mid) + + @test views_overlap(A_r2, mid) + @test views_overlap(B_r2, mid) + @test views_overlap(A_c2, mid) + @test views_overlap(B_c2, mid) + end + + @test views_overlap(A_mid, A_mid) + @test views_overlap(A_mid, B_mid) end # FIXME: Deps