Skip to content

Commit c306b00

Browse files
Merge pull request #179 from SciML/colon
Fix all colon dispatch for GPUs
2 parents b8ebaca + 8f5f185 commit c306b00

File tree

8 files changed

+58
-5
lines changed

8 files changed

+58
-5
lines changed

.buildkite/pipeline.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
steps:
2+
- label: "GPU"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "1"
6+
- JuliaCI/julia-test#v1:
7+
coverage: false # 1000x slowdown
8+
agents:
9+
queue: "juliagpu"
10+
cuda: "*"
11+
env:
12+
GROUP: 'GPU'
13+
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
14+
# SECRET_CODECOV_TOKEN: "..."
15+
timeout_in_minutes: 30
16+
# Don't run Buildkite if the commit message includes the text [skip tests]
17+
if: build.message !~ /\[skip tests\]/

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
matrix:
1414
group:
1515
- Core
16+
- Downstream
1617
version:
1718
- '1'
1819
- '1.6'

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
44
version = "2.21.1"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Requires, RecipesBase, StaticArrays, Statistics,
1010

1111
import ChainRulesCore
1212
import ChainRulesCore: NoTangent
13-
import ZygoteRules
13+
import ZygoteRules, Adapt
1414

1515
using FillArrays
1616

src/vector_of_array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},
6565
I::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) where {T, N}
6666
RecursiveArrayTools.VectorOfArray(A.u)[I...]
6767
end
68+
69+
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
70+
I::Colon...) where {T, N}
71+
@assert length(I) == ndims(A.u[1])+1
72+
vecs = vec.(A.u)
73+
return Adapt.adapt(__parameterless_type(T),reshape(reduce(hcat,vecs),size(A.u[1])...,length(A.u)))
74+
end
75+
6876
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Colon) where {T, N} = [A.u[j][i] for j in 1:length(A)]
6977
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
7078
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]

test/gpu/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

test/gpu/vectorofarray_gpu.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using RecursiveArrayTools, CUDA
2+
CUDA.allowscalar(false)
3+
4+
x = zeros(5)
5+
y = VectorOfArray([x,x,x])
6+
y[:,:]
7+
8+
x = CUDA.zeros(5)
9+
y = VectorOfArray([x,x,x])
10+
y[:,:]

test/runtests.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@ function activate_downstream_env()
1111
Pkg.instantiate()
1212
end
1313

14+
function activate_gpu_env()
15+
Pkg.activate("gpu")
16+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
17+
Pkg.instantiate()
18+
end
19+
1420
@time begin
21+
22+
if !is_APPVEYOR && GROUP == "Core"
1523
@time @testset "Utils Tests" begin include("utils_test.jl") end
1624
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
1725
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@@ -20,9 +28,15 @@ end
2028
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
2129
@time @testset "Upstream Tests" begin include("upstream.jl") end
2230
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
31+
end
32+
33+
if !is_APPVEYOR && GROUP == "Downstream"
34+
activate_downstream_env()
35+
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
36+
end
2337

24-
if !is_APPVEYOR && GROUP == "Downstream"
25-
activate_downstream_env()
26-
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
27-
end
38+
if !is_APPVEYOR && GROUP == "GPU"
39+
activate_gpu_env()
40+
@time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
41+
end
2842
end

0 commit comments

Comments
 (0)