-
Notifications
You must be signed in to change notification settings - Fork 93
Rules for filter
#570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rules for filter
#570
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just needs a version bump
src/rulesets/Base/array.jl
Outdated
##### `filter` | ||
##### | ||
|
||
function frule((_, xdot), ::typeof(filter), f, x::AbstractArray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function frule((_, xdot), ::typeof(filter), f, x::AbstractArray) | |
function frule((_, xdot), ::typeof(filter), f, x::Union{AbstractArray, Tuple}) |
AbstractSet
would need a different rule I think, but probably not urgent
Similarly for rrule
, and maybe let's add a few tests with tuples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Dict and Set are weird, someone can add them later?
I did also add reverse(::Tuple)
as that's the same rule. Makes the diff huge due to test indent, sadly.
test/rulesets/Base/array.jl
Outdated
xt10 = Tuple(vcat(0, 1, rand(8))) # guarantee that not all > or < 0.5 | ||
|
||
# Forward | ||
@test_skip test_frule(filter, >(0.5) ⊢ NoTangent(), xt10; check_inferred=false) # check_result.jl:104 Expression: ActualPrimal === ExpectedPrimal Evaluated: NTuple{10, Float64} === NTuple{6, Float64} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are quite a few test errors, which I believe are all in CRTU not the rule. I can't mark things broken but have left them skipped. And mostly added explicit tests below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we open issues on CRTU and cross link the issue number back into a comment above the @test_skipped
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issues 222-233 :)
I can add links here. I have not made MWEs there. I guess I could run them again on a Julia version with working stack traces.
src/rulesets/Base/indexing.jl
Outdated
function rrule(::typeof(getindex), x::Tuple, ind) | ||
y = x[ind] | ||
z = map(Returns(NoTangent()), x) | ||
project = ProjectTo(x) | ||
function getindex_pullback(ȳ, ind::Integer) | ||
x̄ = Base.setindex(z, ȳ, ind) | ||
return (NoTangent(), project(x̄), NoTangent()) | ||
end | ||
function getindex_pullback(ȳ, ind) | ||
x̄ = z | ||
for (i, yi) in zip(ind, unthunk(ȳ)) | ||
x̄ = ntuple(k -> k==i ? x̄[i] + yi : x̄[k], length(z)) | ||
end | ||
return (NoTangent(), project(x̄), NoTangent()) | ||
end | ||
getindex_pullback(ȳ) = getindex_pullback(ȳ, ind) | ||
return y, getindex_pullback | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turns out that filter(::Tuple)
wants an rrule
for getindex
too. I'm not entirely sure this should exist, rather than being left for AD, as it's a pretty low-level operation.
The rrule for Filter could call rrule_via_ad
. Or perhaps it too need not exist, since that is also just one unrolled Base.afoldl
call, which AD could surely handle. I added it without enough coffee, did not try to benchmark with/without it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or perhaps it too need not exist
"it" here refers to using therrule
/rrule_via_ad
? I don't know whatBase.afoldl
does. Do you mean we could just not use therrule
, and write the pullback withBase.afoldl
? And the "AD could surely handle" refers to higher order derivatives?
Barring any misunderstanding It seems that deleting the rule and using rrule_via_ad
seems like the most elegant option?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying things out, the rule for getindex
is not used. The rule for filter
does speed this up, although whether it's fair to time Diffractor pre-release I don't know.
julia> @btime Diffractor.gradient(x -> +(filter(iseven, x)...), $((1,2,3,4,5)));
min 18.375 μs, mean 19.967 μs (195 allocations, 9.34 KiB) # master
min 5.840 μs, mean 6.101 μs (49 allocations, 1.58 KiB) # with both rules
julia> @btime Zygote.gradient(x -> +(filter(iseven, x)...), $((1,2,3,4,5)));
min 130.041 μs, mean 135.277 μs (388 allocations, 15.22 KiB) # tagged
min 5.375 μs, mean 5.676 μs (54 allocations, 1.95 KiB) # rules
julia> @btime Diffractor.gradient(x -> x[1] + x[2], $((1,2,3,4,5)));
min 367.718 ns, mean 400.458 ns (15 allocations, 368 bytes) #
min 363.527 ns, mean 401.391 ns (15 allocations, 368 bytes) # does not call the rule
julia> @btime Zygote.gradient(x -> x[1] + x[2], $((1,2,3,4,5)));
min 1.458 ns, mean 1.548 ns (0 allocations) # with or without
The rrule
for filter(f, ::Tuple)
is a bit ugly, as it allocates an array, while filter
is some clever unrolled tuple thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying things out, the rule for getindex is not used.
Sorry, I am not sure I understand why did it break in the first place, and why is it not breaking now if the rule for getindex
is not used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These give the right answer with or without rules. If the rule for filter is present it gets called, and it explicitly calls the rule for getindex. But calling getindex does not lead to AD calling the rule for that, according to println, presumably an internal rule handles it before checking CR.
The original motivation for the PR was that filter on a Vector does not work in Diffractor at the moment, since that mutates an array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, makes sense, thanks. I didn't realise that Tuples worked without the rule :) . Feel free to merge when CI passes.
Also, needs a version bump still
No description provided.