Skip to content

Rules for filter #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.19"
version = "1.20"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.11"
ChainRulesCore = "1.11.5"
ChainRulesTestUtils = "1"
Compat = "3.35"
FiniteDifferences = "0.12.20"
Expand Down
25 changes: 22 additions & 3 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,11 @@ end
# 1-dim case allows start/stop, N-dim case takes dims keyword
# whose defaults changed in Julia 1.6... just pass them all through:

function frule((_, xdot), ::typeof(reverse), x::AbstractArray, args...; kw...)
function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
end

function rrule(::typeof(reverse), x::AbstractArray, args...; kw...)
function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
nots = map(Returns(NoTangent()), args)
function reverse_pullback(dy)
dx = @thunk reverse(unthunk(dy), args...; kw...)
Expand Down Expand Up @@ -360,12 +360,31 @@ function frule((_, xdot), ::typeof(fill), x::Any, dims...)
end

function rrule(::typeof(fill), x::Any, dims...)
project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
project = ProjectTo(x)
nots = map(Returns(NoTangent()), dims)
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
return fill(x, dims...), fill_pullback
end

#####
##### `filter`
#####

function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray)
inds = findall(f, x)
return x[inds], xdot[inds]
end

function rrule(::typeof(filter), f, x::AbstractArray)
inds = findall(f, x)
y, back = rrule(getindex, x, inds)
function filter_pullback(dy)
_, dx, _ = back(dy)
return (NoTangent(), NoTangent(), dx)
end
return y, filter_pullback
end

#####
##### `findmax`, `maximum`, etc.
#####
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)

return y, getindex_pullback
end

77 changes: 52 additions & 25 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,32 +164,38 @@ end
end

@testset "reverse" begin
# Forward
test_frule(reverse, rand(5))
test_frule(reverse, rand(5), 2, 4)
test_frule(reverse, rand(5), fkwargs=(dims=1,))

test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
if VERSION >= v"1.6"
test_frule(reverse, rand(3,4))
test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
@testset "Tuple" begin
test_frule(reverse, Tuple(rand(10)))
@test_skip test_rrule(reverse, Tuple(rand(10))) # Ambiguity in isapprox, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/229
end

# Reverse
test_rrule(reverse, rand(5))
test_rrule(reverse, rand(5), 2, 4)
test_rrule(reverse, rand(5), fkwargs=(dims=1,))

test_rrule(reverse, rand(3,4), fkwargs=(dims=2,))
if VERSION >= v"1.6"
test_rrule(reverse, rand(3,4))
test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))

# Structured
y, pb = rrule(reverse, Diagonal([1,2,3]))
# We only preserve structure in this case if given structured tangent (no ProjectTo)
@test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal
@test unthunk(pb(rand(3, 3))[2]) isa AbstractArray
@testset "Array" begin
# Forward
test_frule(reverse, rand(5))
test_frule(reverse, rand(5), 2, 4)
test_frule(reverse, rand(5), fkwargs=(dims=1,))

test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
if VERSION >= v"1.6"
test_frule(reverse, rand(3,4))
test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
end

# Reverse
test_rrule(reverse, rand(5))
test_rrule(reverse, rand(5), 2, 4)
test_rrule(reverse, rand(5), fkwargs=(dims=1,))

test_rrule(reverse, rand(3,4), fkwargs=(dims=2,))
if VERSION >= v"1.6"
test_rrule(reverse, rand(3,4))
test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))

# Structured
y, pb = rrule(reverse, Diagonal([1,2,3]))
# We only preserve structure in this case if given structured tangent (no ProjectTo)
@test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal
@test unthunk(pb(rand(3, 3))[2]) isa AbstractArray
end
end
end

Expand All @@ -215,6 +221,27 @@ end
test_rrule(fill, 3.3, (3, 3, 3))
end

@testset "filter" begin
@testset "Array" begin
# Random numbers will confuse finite differencing here, as it may perturb across the boundary.
x5 = [0.0, 1.0, 0.2, 0.9, 0.7]
x34 = Float64[-113 124 -37 12
96 -89 103 119
91 -21 -110 10]

# Forward
test_frule(filter, >(0.5) ⊢ NoTangent(), x5)
test_frule(filter, <(0), x34)
test_frule(filter, >(100), x5)

# Reverse
test_rrule(filter, >(0.5) ⊢ NoTangent(), x5) # Without ⊢, MethodError: zero(::Base.Fix2{typeof(>), Float64}) -- https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/231
test_rrule(filter, <(0), x34)
test_rrule(filter, >(100), x5) # fixed in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/534
@test unthunk(rrule(filter, >(100), x5)[2](Int[])[3]) == zero(x5)
end
end

@testset "findmin & findmax" begin
# Forward
test_frule(findmin, rand(10))
Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testset "getindex" begin
@testset "getindex(::Matrix{<:Number},...)" begin
@testset "getindex(::Matrix{<:Number}, ...)" begin
x = [1.0 2.0 3.0; 10.0 20.0 30.0]

@testset "single element" begin
Expand Down