Skip to content

Commit 5df0caf

Browse files
fix: add RecursiveArrayToolsReverseDiffExt
1 parent e3d3978 commit 5df0caf

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2424
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2627

2728
[extensions]
2829
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
2930
RecursiveArrayToolsMeasurementsExt = "Measurements"
3031
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
32+
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
3133
RecursiveArrayToolsTrackerExt = "Tracker"
3234
RecursiveArrayToolsZygoteExt = "Zygote"
3335

@@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62"
4951
Pkg = "1"
5052
Random = "1"
5153
RecipesBase = "1.1"
54+
ReverseDiff = "1.15"
5255
SafeTestsets = "0.1"
5356
SparseArrays = "1.10"
5457
StaticArrays = "1.6"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module RecursiveArrayToolsReverseDiffExt
2+
3+
using RecursiveArrayTools
4+
using ReverseDiff
5+
using Zygote: @adjoint
6+
7+
@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal})
8+
function Array_adjoint(y)
9+
VA = recursivecopy(VA)
10+
for (i, slice) in zip(eachindex(VA.u), eachslice(y, dims=ndims(y)))
11+
VA.u[i] = reshape(reduce(vcat, slice), size(VA.u[i]))
12+
end
13+
return (VA,)
14+
end
15+
return Array(VA), Array_adjoint
16+
end
17+
end # module

0 commit comments

Comments
 (0)