Incompatible signature of differential operators for functions with inout parameters #67818
Tagging @rxwei @BradLarson @dan-zheng. Any decent ideas on how the issue could be solved in general?
My understanding (from asking @rxwei a long time ago):
Is there really a use case for directly applying a differential operator to an
Some examples below. Sorry if the code doesn't actually compile.
```swift
import _Differentiation

func foo(_ x: Float, _ weight: Float, _ bias: Float, useBias: Bool) -> Float {
  if useBias {
    return x * weight + bias
  }
  return x * weight
}

// This isn't supported, and that's intended.
gradient(of: foo)
// Do this instead:
_ = gradient(of: { (x: Float, weight: Float, bias: Float) in foo(x, weight, bias, useBias: true) })
_ = gradient(of: { (x: Float, weight: Float, bias: Float) in foo(x, weight, bias, useBias: false) })
```
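To make the wrapper pattern concrete, here is a minimal sketch that fixes `useBias` and evaluates the resulting gradient function at one sample point. It assumes a toolchain that ships the `_Differentiation` module; `gradWithBias` and the input values are illustrative only.

```swift
import _Differentiation

func foo(_ x: Float, _ weight: Float, _ bias: Float, useBias: Bool) -> Float {
  if useBias {
    return x * weight + bias
  }
  return x * weight
}

// Capture the non-differentiable `useBias` flag as a constant, so the closure
// has only Float parameters and can be passed to `gradient(of:)`.
let gradWithBias = gradient(of: { (x: Float, weight: Float, bias: Float) -> Float in
  foo(x, weight, bias, useBias: true)
})

// For x * weight + bias: d/dx = weight, d/dweight = x, d/dbias = 1.
let (dx, dw, db) = gradWithBias(2, 3, 1)
print(dx, dw, db) // 3.0 2.0 1.0
```

Because `useBias` is fixed inside the closure, the returned function yields exactly one partial derivative per `Float` parameter.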
```swift
func swap(_ x: inout Float, _ y: inout Float) {
  let tmp = x
  x = y
  y = tmp
}

// This isn't supported, and that's intended with the current design.
gradient(of: swap)

// This is supported: a functionally-typed wrapper closure.
// This is highly contrived.
// Usually `inout` functions are used in the guts of differentiable functions,
// not directly "differentiated" like this.
_ = gradient(of: { (x: Float, y: Float) -> Float in
  var newX = x
  var newY = y
  swap(&newX, &newY)
  return newX
})
```
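The same wrapper pattern can also be written as a named function that takes plain values and returns the mutated copy. A minimal sketch, assuming a toolchain with `_Differentiation`; `scaleInPlace` and `scaled` are hypothetical names used only for illustration:

```swift
import _Differentiation

// Hypothetical in-place helper (illustration only): multiplies `x` by `s`.
@differentiable(reverse)
func scaleInPlace(_ x: inout Float, by s: Float) {
  x = x * s
}

// Functional wrapper: copy the argument into a local `var`, mutate the copy
// through the inout helper, and return it as an ordinary result.
@differentiable(reverse)
func scaled(_ x: Float, _ s: Float) -> Float {
  var result = x
  scaleInPlace(&result, by: s)
  return result
}

// `scaled` has the plain `(Float, Float) -> Float` shape expected by
// `gradient(at:_:of:)`, so its gradient has the usual shape as well.
let (dx, ds) = gradient(at: 3, 4, of: scaled)
print(dx, ds) // 4.0 3.0
```

The inout mutation stays an implementation detail of `scaled`, so the differential operator only ever sees a `(T, U) -> R` function.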
Thanks @dan-zheng, this is very helpful and clears up some of the concerns, though it seems we'd need to figure out what to do in more "fun" cases. Consider this more elaborate code:

```swift
class Class: Differentiable {
  @differentiable(reverse, wrt: (self, x))
  func f(_ x: Float) -> Float {
    // do something weird, mutating the internals as well
  }
}

func test<C: Class>(_ c: C, _ x: Float) {
  _ = gradient(at: c, x) { c, x in c.f(x) } // or pullback
}
```
So it looks like we'd essentially need to extend the type checker to disallow non-value-typed arguments in general?
My understanding of your question:
If my understanding is correct:
Well, my point is that they are not orthogonal if we look at differential operators from another angle. Essentially, they expect that for the
I think we already have plenty of those, e.g. #55542 and lots of FIXMEs in the tests / code :)
I see, that's a fair point.
My initial reaction: yes, further constraining differential operator types (to disallow non-sensible cases like class-typed arguments) makes sense to me. Other folks might also have different ideas. +1 from me for being use-case-driven in the differential operator design and working to support real non-toy use cases (I'm not sure what these are for class differentiation).
My main use case here is quite simple, and it would probably cover 80% of sane use cases: Array. Maybe Dictionary as well, but with a much smaller footprint.
I'm not sure I understood your comment above, since Array and Dictionary are value-semantics structs, not classes.
How about constraining differential operators to disallow functions with class arguments (i.e.
Yes, this is something I also thought about. I'll try experimenting with it. Thanks!
@JaapWijnen Found this issue with some additional things.
Consider the following code:
This produces the following set of errors:
And indeed, `pullback(at:of:)` is defined as

There is no way the pullback of `test` could match the required signature here, as its pullback type is `(inout Float) -> Void`. Differential operators such as `pullback` or `valueWithPullback` accidentally work in more complex cases simply because `R` gets matched to something else (e.g. if there are other function results), which can, however, cause subtle issues here and there.

The main invariant violated here is that the pullback type for `(T) -> R` is `(R.TangentVector) -> T.TangentVector`, which does not quite hold when inouts are involved. It will hold even less with class-value differentiation, but for now let's focus on the already-supported inouts.
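For comparison, wrapping an inout function in a closure of type `(T) -> R` restores the expected shape, as in the `swap` example earlier in the thread. A minimal sketch, assuming a toolchain with `_Differentiation`; `squareInPlace` is a hypothetical stand-in for `test`, whose code is not shown here. The pullback returned by `pullback(at:of:)` then has the `(R.TangentVector) -> T.TangentVector` shape, i.e. `(Float) -> Float`:

```swift
import _Differentiation

// Hypothetical stand-in for the issue's `test`: squares its argument in place.
// Its own pullback would be `(inout Float) -> Void`, not `(Float) -> Float`.
@differentiable(reverse)
func squareInPlace(_ x: inout Float) {
  x = x * x
}

// Wrap the inout call so the differentiated function is `(Float) -> Float`.
let pb: (Float) -> Float = pullback(at: Float(3)) { (x: Float) -> Float in
  var y = x
  squareInPlace(&y)
  return y
}

// d(x * x)/dx at x = 3 is 6, scaled by the incoming cotangent 1.
print(pb(1)) // 6.0
```

This only sidesteps the mismatch for callers; it does not change the fact that the operators' signatures cannot express inout differentiability directly.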