Skip to content

Rules for filter #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 18, 2022
Merged

Rules for filter #570

merged 10 commits into from
Jan 18, 2022

Conversation

mcabbott
Copy link
Member

No description provided.

Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just needs a version bump

##### `filter`
#####

function frule((_, xdot), ::typeof(filter), f, x::AbstractArray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function frule((_, xdot), ::typeof(filter), f, x::AbstractArray)
function frule((_, xdot), ::typeof(filter), f, x::Union{AbstractArray, Tuple})

AbstractSet would need a different rule I think, but probably not urgent

Similarly for rrule, and maybe let's add a few tests with tuples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Dict and Set are weird, someone can add them later?

I did also add reverse(::Tuple) as that's the same rule. Makes the diff huge due to test indent, sadly.

xt10 = Tuple(vcat(0, 1, rand(8))) # guarantee that not all > or < 0.5

# Forward
@test_skip test_frule(filter, >(0.5) ⊢ NoTangent(), xt10; check_inferred=false) # check_result.jl:104 Expression: ActualPrimal === ExpectedPrimal Evaluated: NTuple{10, Float64} === NTuple{6, Float64}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are quite a few test errors, which I believe are all in CRTU not the rule. I can't mark things broken but have left them skipped. And mostly added explicit tests below.

Copy link
Member

@oxinabox oxinabox Jan 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we open issues on CRTU and cross link the issue number back into a comment above the @test_skipped ?

Copy link
Member Author

@mcabbott mcabbott Jan 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issues 222-233 :)

I can add links here. I have not made MWEs there. I guess I could run them again on a Julia version with working stack traces.

Comment on lines 5 to 22
function rrule(::typeof(getindex), x::Tuple, ind)
y = x[ind]
z = map(Returns(NoTangent()), x)
project = ProjectTo(x)
function getindex_pullback(ȳ, ind::Integer)
x̄ = Base.setindex(z, ȳ, ind)
return (NoTangent(), project(x̄), NoTangent())
end
function getindex_pullback(ȳ, ind)
x̄ = z
for (i, yi) in zip(ind, unthunk(ȳ))
x̄ = ntuple(k -> k==i ? x̄[i] + yi : x̄[k], length(z))
end
return (NoTangent(), project(x̄), NoTangent())
end
getindex_pullback(ȳ) = getindex_pullback(ȳ, ind)
return y, getindex_pullback
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out that filter(::Tuple) wants an rrule for getindex too. I'm not entirely sure this should exist, rather than being left for AD, as it's a pretty low-level operation.

The rrule for Filter could call rrule_via_ad. Or perhaps it too need not exist, since that is also just one unrolled Base.afoldl call, which AD could surely handle. I added it without enough coffee, did not try to benchmark with/without it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or perhaps it too need not exist
"it" here refers to using the rrule/rrule_via_ad? I don't know what Base.afoldl does. Do you mean we could just not use the rrule, and write the pullback with Base.afoldl? And the "AD could surely handle" refers to higher order derivatives?

Barring any misunderstanding It seems that deleting the rule and using rrule_via_ad seems like the most elegant option?

Copy link
Member Author

@mcabbott mcabbott Jan 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying things out, the rule for getindex is not used. The rule for filter does speed this up, although whether it's fair to time Diffractor pre-release I don't know.

julia> @btime Diffractor.gradient(x -> +(filter(iseven, x)...), $((1,2,3,4,5)));
  min 18.375 μs, mean 19.967 μs (195 allocations, 9.34 KiB)  # master
  min 5.840 μs, mean 6.101 μs (49 allocations, 1.58 KiB)  # with both rules

julia> @btime Zygote.gradient(x -> +(filter(iseven, x)...), $((1,2,3,4,5)));
  min 130.041 μs, mean 135.277 μs (388 allocations, 15.22 KiB)  # tagged
  min 5.375 μs, mean 5.676 μs (54 allocations, 1.95 KiB)  # rules
  
julia> @btime Diffractor.gradient(x -> x[1] + x[2], $((1,2,3,4,5)));
  min 367.718 ns, mean 400.458 ns (15 allocations, 368 bytes)  # 
  min 363.527 ns, mean 401.391 ns (15 allocations, 368 bytes)  # does not call the rule

julia> @btime Zygote.gradient(x -> x[1] + x[2], $((1,2,3,4,5)));
  min 1.458 ns, mean 1.548 ns (0 allocations)  # with or without

The rrule for filter(f, ::Tuple) is a bit ugly, as it allocates an array, while filter is some clever unrolled tuple thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying things out, the rule for getindex is not used.

Sorry, I am not sure I understand why did it break in the first place, and why is it not breaking now if the rule for getindex is not used?

Copy link
Member Author

@mcabbott mcabbott Jan 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These give the right answer with or without rules. If the rule for filter is present it gets called, and it explicitly calls the rule for getindex. But calling getindex does not lead to AD calling the rule for that, according to println, presumably an internal rule handles it before checking CR.

The original motivation for the PR was that filter on a Vector does not work in Diffractor at the moment, since that mutates an array.

Copy link
Member

@mzgubic mzgubic Jan 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, makes sense, thanks. I didn't realise that Tuples worked without the rule :) . Feel free to merge when CI passes.

Also, needs a version bump still

@mcabbott mcabbott merged commit b197e62 into JuliaDiff:main Jan 18, 2022
@mcabbott mcabbott deleted the filter branch January 18, 2022 02:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants