@@ -42,3 +42,45 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
42
42
end
43
43
A. x,literal_ArrayPartition_x_adjoint
44
44
end
45
+
46
+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i)
47
+ function AbstractVectorOfArray_getindex_adjoint (Δ)
48
+ Δ′ = [ (i == j ? Δ : zero (x)) for (x,j) in zip (VA. u, 1 : length (VA))]
49
+ (Δ′,nothing )
50
+ end
51
+ VA[i],AbstractVectorOfArray_getindex_adjoint
52
+ end
53
+
54
+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i, j... )
55
+ function AbstractVectorOfArray_getindex_adjoint (Δ)
56
+ Δ′ = zero (VA)
57
+ Δ′[i,j... ] = Δ
58
+ (Δ′, i,map (_ -> nothing , j)... )
59
+ end
60
+ VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
61
+ end
62
+
63
+ ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
64
+ function ArrayPartition_adjoint (_y)
65
+ y = Array (_y)
66
+ starts = vcat (0 ,cumsum (reduce (vcat,length .(x))))
67
+ ntuple (i -> reshape (y[starts[i]+ 1 : starts[i+ 1 ]], size (x[i])), length (x)), nothing
68
+ end
69
+
70
+ ArrayPartition (x, Val{copy_x}), ArrayPartition_adjoint
71
+ end
72
+
73
+ ZygoteRules. @adjoint function VectorOfArray (u)
74
+ VectorOfArray (u),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],)
75
+ end
76
+
77
+ ZygoteRules. @adjoint function DiffEqArray (u,t)
78
+ DiffEqArray (u,t),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],nothing )
79
+ end
80
+
81
+ ZygoteRules. @adjoint function ZygoteRules. literal_getproperty (A:: ArrayPartition , :: Val{:x} )
82
+ function literal_ArrayPartition_x_adjoint (d)
83
+ (ArrayPartition ((isnothing (d[i]) ? zero (A. x[i]) : d[i] for i in 1 : length (d)). .. ),)
84
+ end
85
+ A. x,literal_ArrayPartition_x_adjoint
86
+ end
0 commit comments