From 167c9eb79a52108063a61a71ebb7880436b74228 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 10 Sep 2021 14:29:50 -0400 Subject: [PATCH 1/2] don't check in Manifest.toml --- Manifest.toml | 274 -------------------------------------------------- 1 file changed, 274 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 9c8e4a39..00000000 --- a/Manifest.toml +++ /dev/null @@ -1,274 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -manifest_format = "2.0" - -[[deps.Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.1" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.ChainRules]] -deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "d88340ab502af66cfffc821e70ae72f7dbdce645" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.11.5" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "30ee06de5ff870b45c78f529a6b093b3323256a3" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.3.1" - -[[deps.Combinatorics]] -git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" -uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -version = "1.0.2" - -[[deps.Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "6071cb87be6a444ac75fdbf51b8e7273808ce62f" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.35.0" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.5.0+0" - -[[deps.DataAPI]] -git-tree-sha1 = "bec2532f8adb82005476c141ec23e921fc20971b" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.8.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.10" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.5.1" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.73.0+4" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.9.1+2" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.24.0+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "2ca267b08821e86c5ef4376cffed98a46c2cb205" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.1" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2020.7.22" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.13+7" - -[[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.1" - -[[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.12" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.StatsAPI]] -git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.0.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "8cbbc098554648c84f79a463c9ff0fd277144b6c" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.10" - -[[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"] -git-tree-sha1 = "1700b86ad59348c0f9f68ddc95117071f947072d" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "368d04a820fe069f9080ff1b432147a6203c3c89" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.5.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+1" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "3.1.0+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.41.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "16.2.1+1" From 97094868fbdfce37670678ef4e149355906369f8 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 10 Sep 2021 13:41:47 -0400 Subject: [PATCH 2/2] adapt to JuliaLang/julia#42125 --- src/runtime.jl | 14 ++++----- src/stage1/forward.jl | 14 ++++----- src/stage1/generated.jl | 64 ++++++++++++++++++++--------------------- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/runtime.jl b/src/runtime.jl index 0b48206d..09f6e3e6 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -1,14 +1,14 @@ using ChainRulesCore -@Base.aggressive_constprop accum(a, b) = a + b -@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b) -@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple) +@Base.constprop :aggressive accum(a, b) = a + b +@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b) +@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple) fnames = union(fieldnames(x), fieldnames(y)) gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent()) grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent()) Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...) end -@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...) -@Base.aggressive_constprop accum(a::NoTangent, b) = b -@Base.aggressive_constprop accum(a, b::NoTangent) = a -@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent() +@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...) +@Base.constprop :aggressive accum(a::NoTangent, b) = b +@Base.constprop :aggressive accum(a, b::NoTangent) = a +@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent() diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index f2eda524..3812eee7 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -134,37 +134,37 @@ end (::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...) # Special case rules for performance -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} +@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) TangentBundle{N}(getfield(primal(x), s), map(x->lifted_getfield(x, s), x.partials)) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N} +@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) TaylorBundle{N}(getfield(primal(x), s), map(y->lifted_getfield(y, s), x.coeffs)) end -@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} +@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} x.tup[primal(s)] end -@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} +@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} x.tup[Base.fieldindex(B, primal(s))] end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N} +@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N} s = primal(s) TangentBundle{N}(getfield(primal(x), s, primal(inbounds)), map(x->lifted_getfield(x, s), x.partials)) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} +@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U} +@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial) end diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 62547309..4e75945a 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -38,8 +38,8 @@ struct Protected{N} a end (p::Protected)(args...) = getfield(p, :a)(args...)[1] -@Base.aggressive_constprop (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...) -@Base.aggressive_constprop (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...) +@Base.constprop :aggressive (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...) +@Base.constprop :aggressive (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...) (::∂⃖{N})(p::Protected, args...) where {N} = error("TODO: Can we support this?") struct OpticBundle{T} @@ -94,30 +94,30 @@ end end struct ∂⃖weaveInnerOdd{N, O}; b̄; end -@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N} +@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N} @destruct c, c̄ = w.b̄(Δ...) return (c̄, c) end -@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O} +@Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O} @destruct c, c̄ = w.b̄(Δ...) return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}() end struct ∂⃖weaveInnerEven{N, O}; end -@Base.aggressive_constprop function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O} +@Base.constprop :aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O} @destruct y, ȳ = Δ′(x...) return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ) end struct ∂⃖weaveOuterOdd{N, O}; end -@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N} +@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N} return (NoTangent(), Δ′′′(Δ′′)...) end -@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O} +@Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O} @destruct α, ᾱ = Δ′′′(Δ′′) return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ) end struct ∂⃖weaveOuterEven{N, O}; ᾱ end -@Base.aggressive_constprop function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O} +@Base.constprop :aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O} return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}() end @@ -156,33 +156,33 @@ struct ∂⃖rruleB{N, O}; ᾱ; ȳ̄ ; end struct ∂⃖rruleC{N, O}; ȳ̄ ; Δ′′′; β̄ ; end struct ∂⃖rruleD{N, O}; γ̄; β̄ ; end -@Base.aggressive_constprop function (a::∂⃖rruleA{N, O})(Δ) where {N, O} +@Base.constprop :aggressive function (a::∂⃖rruleA{N, O})(Δ) where {N, O} # TODO: Is this unthunk in the right place @destruct (α, ᾱ) = a.∂(a.ȳ, unthunk(Δ)) (α, ∂⃖rruleB{N, O}(ᾱ, a.ȳ̄)) end -@Base.aggressive_constprop function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O} +@Base.constprop :aggressive function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O} @destruct ((Δ′′′, β), β̄) = b.ᾱ(Δ′) (β, ∂⃖rruleC{N, O}(b.ȳ̄, Δ′′′, β̄)) end -@Base.aggressive_constprop function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O} +@Base.constprop :aggressive function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O} @destruct (γ, γ̄) = c.ȳ̄((Δ′′, c.Δ′′′)) (Base.tail(γ), ∂⃖rruleD{N, O}(γ̄, c.β̄)) end -@Base.aggressive_constprop function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O} +@Base.constprop :aggressive function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O} (δ₁, δ₂), δ̄ = d.γ̄(ZeroTangent(), Δ⁴...) (δ₁, ∂⃖rruleA{N, O+1}(d.β̄ , δ₂, δ̄ )) end # Terminal cases -@Base.aggressive_constprop function (c::∂⃖rruleB{N, N})(Δ′...) where {N} +@Base.constprop :aggressive function (c::∂⃖rruleB{N, N})(Δ′...) where {N} @destruct (Δ′′′, β) = c.ᾱ(Δ′) (β, ∂⃖rruleC{N, N}(c.ȳ̄, Δ′′′, nothing)) end -@Base.aggressive_constprop (c::∂⃖rruleC{N, N})(Δ′′) where {N} = +@Base.constprop :aggressive (c::∂⃖rruleC{N, N})(Δ′′) where {N} = Base.tail(c.ȳ̄((Δ′′, c.Δ′′′))) (::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached") @@ -255,9 +255,9 @@ function ChainRulesCore.rrule(::KwFunc, kwargs, f, args...) end end -@Base.aggressive_constprop function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol) +@Base.constprop :aggressive function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol) getfield(s, field), let P = typeof(s) - @Base.aggressive_constprop Δ->begin + @Base.constprop :aggressive Δ->begin nt = NamedTuple{(field,)}((Δ,)) (NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent()) end @@ -265,7 +265,7 @@ end end struct ∂⃖getfield{n, f}; end -@Base.aggressive_constprop function (::∂⃖getfield{n, f})(Δ) where {n,f} +@Base.constprop :aggressive function (::∂⃖getfield{n, f})(Δ) where {n,f} if @generated return Expr(:call, tuple, NoTangent(), Expr(:call, tuple, (i == f ? :(Δ) : ZeroTangent() for i = 1:n)...), @@ -279,31 +279,31 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g) struct EvenOddOdd{O, P, F, G}; f::F; g::G; end EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g) -@Base.aggressive_constprop (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g)) -@Base.aggressive_constprop (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g)) -@Base.aggressive_constprop (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ) +@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g)) +@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g)) +@Base.constprop :aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ) -@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N} +@Base.constprop :aggressive function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N} getfield(s, field), EvenOddOdd{1, c_order(N)}( ∂⃖getfield{nfields(s), field}(), - @Base.aggressive_constprop (_, Δ, _)->getfield(Δ, field)) + @Base.constprop :aggressive (_, Δ, _)->getfield(Δ, field)) end -@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N} +@Base.constprop :aggressive function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N} getfield(s, field), EvenOddOdd{1, c_order(N)}( ∂⃖getfield{nfields(s), field}(), - @Base.aggressive_constprop (_, Δ, _)->lifted_getfield(Δ, field)) + @Base.constprop :aggressive (_, Δ, _)->lifted_getfield(Δ, field)) end function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N} getfield(s, field), let P = typeof(s) EvenOddOdd{1, c_order(N)}( - (@Base.aggressive_constprop Δ->begin + (@Base.constprop :aggressive Δ->begin nt = NamedTuple{(field,)}((Δ,)) (NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent()) end), - (@Base.aggressive_constprop (_, Δs, _)->begin + (@Base.constprop :aggressive (_, Δs, _)->begin isa(Δs, Union{ZeroTangent, NoTangent}) ? Δs : getfield(ChainRulesCore.backing(Δs), field) end)) end @@ -313,13 +313,13 @@ end function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N} getindex(a, inds...), let EvenOddOdd{1, c_order(N)}( - (@Base.aggressive_constprop Δ->begin + (@Base.constprop :aggressive Δ->begin Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...) BB = zero(a) BB[inds...] = Δ (NoTangent(), BB, map(x->NoTangent(), inds)...) end), - (@Base.aggressive_constprop (_, Δ, _)->begin + (@Base.constprop :aggressive (_, Δ, _)->begin getindex(Δ, inds...) end)) end @@ -355,15 +355,15 @@ end struct ApplyOdd{O, P}; u; ∂⃖f; end struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end -@Base.aggressive_constprop function (a::ApplyOdd{O, P})(Δ) where {O, P} +@Base.constprop :aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P} r, ∂⃖∂⃖f = a.∂⃖f(Δ) (a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f)) end -@Base.aggressive_constprop function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P} +@Base.constprop :aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P} r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...) (r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f)) end -@Base.aggressive_constprop function (a::ApplyOdd{O, O})(Δ) where {O} +@Base.constprop :aggressive function (a::ApplyOdd{O, O})(Δ) where {O} r = a.∂⃖f(Δ) a.u(r) end @@ -381,7 +381,7 @@ end Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}() end -@Base.aggressive_constprop lifted_getfield(x, s) = getfield(x, s) +@Base.constprop :aggressive lifted_getfield(x, s) = getfield(x, s) lifted_getfield(x::ZeroTangent, s) = ZeroTangent() lifted_getfield(x::NoTangent, s) = NoTangent()