Since #66873 was merged, the compiler can differentiate through functions with multiple results (such as functions with a differentiable `inout` parameter that also return a result).
Unfortunately, we cannot directly ask for the pullback of these functions yet, because there are no implementations of `valueWithPullback` that take `inout` parameters.
A potential function signature (for arity 1) would be:
```swift
@inlinable
public func valueWithPullback<T, R>(
  at x: inout T,
  of f: @differentiable(reverse) (inout T) -> R
) -> (value: R, pullback: (R.TangentVector, inout T.TangentVector) -> Void) {
  return Builtin.applyDerivative_vjp(f, x) // Currently missing Builtin
}
```
Currently we can work around this missing feature by wrapping the function in a non-`inout` function that makes a copy of the parameter:
```swift
import _Differentiation

@differentiable(reverse)
func square(x: inout Double) { // we can't directly call valueWithPullback on this function
  x = x * x
}

@differentiable(reverse)
func nonInoutSquare(x: Double) -> Double {
  var x = x // copy that direct inout support would avoid
  square(x: &x)
  return x
}

let result = valueWithPullback(at: 5.0, of: nonInoutSquare)
```
Of course, this partly defeats the point, both in expressivity and in performance, since it requires additional copies that would be avoided by using `inout` parameters directly.
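For reference, here is what the copy workaround yields when the pullback is actually evaluated. This sketch writes the copy as an inline closure instead of a named wrapper and assumes a toolchain that ships the `_Differentiation` module. For f(x) = x², the pullback of a unit cotangent at x = 5 should be 2 · 5 = 10.

```swift
import _Differentiation

@differentiable(reverse)
func square(x: inout Double) {
  x = x * x
}

// Copy workaround as a closure: copy the argument, mutate the copy, return it.
let (value, pullback) = valueWithPullback(at: 5.0) { (x: Double) -> Double in
  var x = x // the extra copy an inout-aware overload would avoid
  square(x: &x)
  return x
}

print(value)         // 25.0
print(pullback(1.0)) // 10.0, i.e. d(x * x)/dx at x = 5
```

The `var x = x` line is exactly the copy that a `valueWithPullback` overload taking `inout` parameters would make unnecessary.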
Potential issue:
There are currently three `valueWithPullback` implementations, for arities 1 to 3. Because of the underlying Builtins, we unfortunately can't simplify these using parameter packs (as far as I can tell). Adding overloads with `inout` parameters would greatly increase the number of overloads, since we would need one for every combination of parameters being "normal" or `inout`, for functions both with and without a differentiable result. `inout` parameters also don't lend themselves to parameter packs at this time, unfortunately (afaik).
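To put a rough number on the growth, assume every combination needs its own overload. Each of the n parameters is either normal or `inout` (2^n choices), and the function either returns a differentiable result or not (2 choices). Two variants per arity drop out: the all-normal variant with a result (the overload that already exists) and the all-normal variant without a result (nothing to differentiate). That gives, per arity and in total for arities 1 to 3:

```math
N(n) = 2 \cdot 2^{n} - 2 = 2^{n+1} - 2, \qquad \sum_{n=1}^{3} N(n) = 2 + 6 + 14 = 22
```

Under these assumptions, that is 22 new overloads on top of the existing three.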
Do people see any other potential roadblocks for this feature?
Additional information
No response