diff --git a/Project.toml b/Project.toml index aeb5e90c..a6094a2b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -24,7 +24,7 @@ ArrayInterfaceStaticArrays = "0.1" ChainRulesCore = "0.10.7, 1" DocStringExtensions = "0.8, 0.9" FillArrays = "0.11, 0.12, 0.13" -GPUArrays = "8" +GPUArraysCore = "0.1" RecipesBase = "0.7, 0.8, 1.0" StaticArrays = "0.12, 1.0" ZygoteRules = "0.2" diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 106cb728..eaa9f4b2 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -29,9 +29,9 @@ include("zygote.jl") Base.show(io::IO, x::Union{ArrayPartition,AbstractVectorOfArray}) = invoke(show, Tuple{typeof(io), Any}, io, x) -import GPUArrays -Base.convert(T::Type{<:GPUArrays.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA) -ChainRulesCore.rrule(T::Type{<:GPUArrays.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ) +import GPUArraysCore +Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA) +ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ) export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray, AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples