@@ -8,7 +8,7 @@ const RCR = RuleConfig{>:HasReverseMode}
8
8
@inline only_derivative (y,f:: F ,x) where F = only (only (ChainRulesCore. derivatives_given_output (y, f, x)))
9
9
10
10
# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
11
- # is independent of `x`, as `_return_type ` says `Union{}` when calling is an error.
11
+ # is independent of `x`, as `return_type ` says `Union{}` when calling is an error.
12
12
struct NotaNumber <: Real end
13
13
14
14
"""
@@ -57,7 +57,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
57
57
end
58
58
59
59
# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
60
- if isconcretetype (Core. Compiler. _return_type (only_derivative, Tuple{T, F, NotaNumber}))
60
+ if isconcretetype (Core. Compiler. return_type (only_derivative, Tuple{T, F, NotaNumber}))
61
61
Ω = bias_act! (σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat}
62
62
function bias_act!_fastback (Δ)
63
63
# Tempting to overwrite x again, but only safe if you call pullback at most once,
@@ -70,7 +70,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
70
70
71
71
# # Slower path: can't overwrite x, but can use derivatives_given_output
72
72
# # This case is WRONG and tests fail, but not sure why
73
- # elseif isconcretetype(Core.Compiler._return_type (only_derivative, Tuple{T, F, T}))
73
+ # elseif isconcretetype(Core.Compiler.return_type (only_derivative, Tuple{T, F, T}))
74
74
# Ω2 = fast_act(σ, x).(x) .+ b
75
75
# @show σ b
76
76
# function bias_act!_back2(Δ)
0 commit comments