From f5edcc5873c76dee6add37f608e7d4603bad4939 Mon Sep 17 00:00:00 2001 From: Pavel Shashkin <pshashkin@go-promo.ru> Date: Tue, 26 Mar 2019 21:44:09 +0300 Subject: [PATCH 1/6] grad for `reverse` --- src/lib/array.jl | 7 +++++++ test/tracker.jl | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index 82d3c819..c5284688 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -240,6 +240,13 @@ Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm) Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm) @grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing) +Base.reverse(xs::TrackedArray; dims) = track(reverse, xs, dims = dims) +@grad reverse(xs; dims) = reverse(data(xs), dims = dims), Δ -> (reverse(Δ, dims = dims), nothing) +Base.reverse(xs::TrackedVector) = track(reverse, xs) +@grad reverse(xs) = reverse(data(xs)), Δ -> (reverse(Δ),) +Base.reverse(xs::TrackedVector, start, stop) = track(reverse, xs, start, stop) +@grad reverse(xs, start, stop) = reverse(data(xs), start, stop), Δ -> (reverse(Δ, start, stop), nothing, nothing) + function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) mat1_rsh = reshape(mat1,(1,m1,1,n1)) diff --git a/test/tracker.jl b/test/tracker.jl index a5412355..052e0fb8 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -125,6 +125,10 @@ end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6)) +@test gradtest(reverse, rand(5)) +@test gradtest(x -> reverse(x, dims=2), rand(4,5,6)) +@test gradtest(x -> reverse(x, 2, 4), rand(5)) + @test gradtest(x -> repeat(x; inner=2), rand(5)) @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) @test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) From b9cb876d8645b2b6d53eb10761f5a36f2603b670 Mon Sep 17 00:00:00 2001 From: Pavel Shashkin <pshashkin@go-promo.ru> Date: Tue, 26 Mar 2019 22:04:44 +0300 Subject: [PATCH 2/6] Lower/UpperTriangular gradients --- src/lib/array.jl | 13 ++++++++++++- test/tracker.jl | 5 ++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 82d3c819..741243e8 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -4,7 +4,7 @@ import LinearAlgebra import LinearAlgebra: inv, det, logdet, logabsdet, \, / using Statistics -using LinearAlgebra: Transpose, Adjoint, diagm, diag +using LinearAlgebra: Transpose, Adjoint, diagm, diag, UpperTriangular, LowerTriangular struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} tracker::Tracked{A} @@ -254,6 +254,17 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b) Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) +LinearAlgebra.UpperTriangular(a::TrackedMatrix) = track(UpperTriangular, a) + +@grad function UpperTriangular(a) + return collect(UpperTriangular(data(a))), Δ -> (collect(UpperTriangular(Δ)),) +end + +LinearAlgebra.LowerTriangular(a::TrackedMatrix) = track(LowerTriangular, a) + +@grad function LowerTriangular(a) + return collect(LowerTriangular(data(a))), Δ -> (collect(LowerTriangular(Δ)),) +end inv(A::TrackedArray) = Tracker.track(inv, A) @grad function inv(A) diff --git a/test/tracker.jl b/test/tracker.jl index a5412355..ff9f75b6 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -2,7 +2,7 @@ using Tracker, Test, NNlib using Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff using NNlib: conv, ∇conv_data, depthwiseconv using Printf: @sprintf -using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet +using LinearAlgebra: diagm, dot, UpperTriangular, LowerTriangular, norm, det, logdet, logabsdet using Statistics: mean, std using Random # using StatsBase @@ -135,6 +135,9 @@ end @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(UpperTriangular, rand(4,4)) +@test gradtest(LowerTriangular, rand(4,4)) + @test gradtest(x -> diagm(0 => x), rand(3)) @test gradtest(W -> inv(log.(W * W)), (5,5)) From 2151cb16dbb3cd87ead32dabef15ddedc2b1e18b Mon Sep 17 00:00:00 2001 From: Pavel Shashkin <pshashkin@go-promo.ru> Date: Tue, 26 Mar 2019 22:13:31 +0300 Subject: [PATCH 3/6] remove collect --- src/lib/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 741243e8..b9f670b1 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -257,13 +257,13 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) LinearAlgebra.UpperTriangular(a::TrackedMatrix) = track(UpperTriangular, a) @grad function UpperTriangular(a) - return collect(UpperTriangular(data(a))), Δ -> (collect(UpperTriangular(Δ)),) + return UpperTriangular(data(a)), Δ -> (UpperTriangular(Δ),) end LinearAlgebra.LowerTriangular(a::TrackedMatrix) = track(LowerTriangular, a) @grad function LowerTriangular(a) - return collect(LowerTriangular(data(a))), Δ -> (collect(LowerTriangular(Δ)),) + return LowerTriangular(data(a)), Δ -> (LowerTriangular(Δ),) end inv(A::TrackedArray) = Tracker.track(inv, A) From 4b64e81d56f6700964534bbaf4e5a252ff7260e8 Mon Sep 17 00:00:00 2001 From: Mike J Innes <mike.j.innes@gmail.com> Date: Thu, 4 Apr 2019 17:02:40 +0100 Subject: [PATCH 4/6] update manifest --- Manifest.toml | 68 +++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 03df385b..174886af 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -21,6 +21,12 @@ git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" version = "0.5.3" +[[CSTParser]] +deps = ["LibGit2", "Test", "Tokenize"] +git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "0.5.2" + [[CommonSubexpressions]] deps = ["Test"] git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" @@ -29,9 +35,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "1.5.1" +version = "2.1.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] @@ -49,28 +55,28 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" [[DiffResults]] deps = ["Compat", "StaticArrays"] -git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7" +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "0.0.3" +version = "0.0.4" [[DiffRules]] deps = ["Random", "Test"] -git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9" +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.0.8" +version = "0.0.10" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] -git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de" +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.2" +version = "0.10.3" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[LibGit2]] @@ -87,31 +93,25 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[MacroTools]] -deps = ["Compat"] -git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" +deps = ["CSTParser", "Compat", "DataStructures", "Test"] +git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.4.4" +version = "0.5.0" [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[Missings]] -deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] -git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.0" - [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f" -repo-rev = "master" +git-tree-sha1 = "e19773c07b66ee3f133dcffcd8f08701427537e7" +repo-rev = "11f840d2f397afc5bdcc2def5523e95c293a76e4" repo-url = "https://github.com/FluxML/NNlib.jl.git" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.4.3+" +version = "0.5.0+" [[NaNMath]] deps = ["Compat"] @@ -160,12 +160,6 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" - [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -178,24 +172,24 @@ version = "0.7.2" [[StaticArrays]] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] -git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.10.2" +version = "0.10.3" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[[StatsBase]] -deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.27.0" - [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[Tokenize]] +deps = ["Printf", "Test"] +git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.3" + [[URIParser]] deps = ["Test", "Unicode"] git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" @@ -203,7 +197,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random"] +deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] From daceb84221f22c1e1c8cf4ec4b929cd15e136b06 Mon Sep 17 00:00:00 2001 From: Pavel Shashkin <pshashkin@go-promo.ru> Date: Tue, 26 Mar 2019 22:04:44 +0300 Subject: [PATCH 5/6] Lower/UpperTriangular gradients --- src/lib/array.jl | 13 ++++++++++++- test/tracker.jl | 5 ++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index c5284688..7931d461 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -4,7 +4,7 @@ import LinearAlgebra import LinearAlgebra: inv, det, logdet, logabsdet, \, / using Statistics -using LinearAlgebra: Transpose, Adjoint, diagm, diag +using LinearAlgebra: Transpose, Adjoint, diagm, diag, UpperTriangular, LowerTriangular struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} tracker::Tracked{A} @@ -261,6 +261,17 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b) Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) +LinearAlgebra.UpperTriangular(a::TrackedMatrix) = track(UpperTriangular, a) + +@grad function UpperTriangular(a) + return collect(UpperTriangular(data(a))), Δ -> (collect(UpperTriangular(Δ)),) +end + +LinearAlgebra.LowerTriangular(a::TrackedMatrix) = track(LowerTriangular, a) + +@grad function LowerTriangular(a) + return collect(LowerTriangular(data(a))), Δ -> (collect(LowerTriangular(Δ)),) +end inv(A::TrackedArray) = Tracker.track(inv, A) @grad function inv(A) diff --git a/test/tracker.jl b/test/tracker.jl index 052e0fb8..ff5192eb 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -2,7 +2,7 @@ using Tracker, Test, NNlib using Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff using NNlib: conv, ∇conv_data, depthwiseconv using Printf: @sprintf -using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet +using LinearAlgebra: diagm, dot, UpperTriangular, LowerTriangular, norm, det, logdet, logabsdet using Statistics: mean, std using Random # using StatsBase @@ -139,6 +139,9 @@ end @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(UpperTriangular, rand(4,4)) +@test gradtest(LowerTriangular, rand(4,4)) + @test gradtest(x -> diagm(0 => x), rand(3)) @test gradtest(W -> inv(log.(W * W)), (5,5)) From 07fad429149dc10284f5cd8960bd989780d94618 Mon Sep 17 00:00:00 2001 From: Pavel Shashkin <pshashkin@go-promo.ru> Date: Tue, 26 Mar 2019 22:13:31 +0300 Subject: [PATCH 6/6] remove collect --- src/lib/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 7931d461..694411d6 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -264,13 +264,13 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) LinearAlgebra.UpperTriangular(a::TrackedMatrix) = track(UpperTriangular, a) @grad function UpperTriangular(a) - return collect(UpperTriangular(data(a))), Δ -> (collect(UpperTriangular(Δ)),) + return UpperTriangular(data(a)), Δ -> (UpperTriangular(Δ),) end LinearAlgebra.LowerTriangular(a::TrackedMatrix) = track(LowerTriangular, a) @grad function LowerTriangular(a) - return collect(LowerTriangular(data(a))), Δ -> (collect(LowerTriangular(Δ)),) + return LowerTriangular(data(a)), Δ -> (LowerTriangular(Δ),) end inv(A::TrackedArray) = Tracker.track(inv, A)