Incompatible signature of differential operators for functions with inout parameters #67818

Open
asl opened this issue Aug 9, 2023 · 11 comments
Labels
AutoDiff bug A deviation from expected or documented behavior. Also: expected but undesirable behavior.

Comments

@asl
Contributor

asl commented Aug 9, 2023

Consider the following code:

import _Differentiation

@differentiable(reverse, wrt : f)
func test(_ f : inout Float) -> Void {
  f *= 42
}

let _ = pullback(at: 1, of: test)

This produces the following set of errors:

inout_parameters.swift:21:29: error: cannot convert value of type '(inout Float) -> Void' to expected argument type '@differentiable(reverse) (Float) -> Void'
let _ = pullback(at: 1, of: test)
                            ^
inout_parameters.swift:21:9: error: type 'Void' cannot conform to 'Differentiable'
let _ = pullback(at: 1, of: test)
        ^
inout_parameters.swift:21:9: note: only concrete types such as structs, enums and classes can conform to protocols
let _ = pullback(at: 1, of: test)
        ^
inout_parameters.swift:21:9: note: required by global function 'pullback(at:of:)' where 'R' = 'Void'
let _ = pullback(at: 1, of: test)
        ^
error: fatalError

And indeed, pullback(at:of:) is defined as

@inlinable
public func pullback<T, R>(
  at x: T, of f: @differentiable(reverse) (T) -> R
) -> (R.TangentVector) -> T.TangentVector {
  return Builtin.applyDerivative_vjp(f, x).1
}

The pullback of test cannot match the required signature, because its pullback type is (inout Float) -> Void. Differential operators such as pullback or valueWithPullback accidentally work in more complex cases simply because R gets matched to something else (e.g. when there are other function results), which can still cause subtle issues here and there.

The main invariant that is violated here is that the pullback type for (T) -> R is (R.TangentVector) -> T.TangentVector which is not quite true when inouts are involved. It will become even more untrue with class values differentiation, but for now let's mention already-supported inouts.
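
To make the invariant concrete, here is a minimal sketch that is not part of the original report. `testFunctional` is a hypothetical wrapper name: it gives `test` the functional shape `(Float) -> Float`, whose pullback type `(Float) -> Float` is what `pullback(at:of:)` expects. It assumes a toolchain that ships the `_Differentiation` module.

```swift
import _Differentiation

@differentiable(reverse, wrt: f)
func test(_ f: inout Float) {
  f *= 42
}

// Hypothetical wrapper (not in the report): copies the argument, mutates the
// copy through `test`, and returns it, so the type is (Float) -> Float.
@differentiable(reverse)
func testFunctional(_ x: Float) -> Float {
  var f = x
  test(&f)
  return f
}

// Here T == R == Float, so the pullback has type (Float) -> Float,
// i.e. (R.TangentVector) -> T.TangentVector.
let pb = pullback(at: 1, of: testFunctional)
print(pb(1))
```

The wrapper only sidesteps the mismatch; `test` itself still cannot be passed to the operator directly.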

@asl added the bug and AutoDiff labels Aug 9, 2023
@asl
Contributor Author

asl commented Aug 9, 2023

Tagging @rxwei @BradLarson @dan-zheng

Any decent ideas how the issue could be solved in general?

@dan-zheng
Contributor

My understanding (from asking @rxwei a long time ago):

  • There's no need to define differential operators for every potential differentiable function type signature.
    • Differential operators not only cannot accept functions with inout parameters, they also cannot accept @differentiable-typed functions with nondiff parameters.
  • Instead, we define differential operators only for "functional" function types where all parameters are differentiable parameters (i.e. wrt).
    • This keeps the API surface small and understandable.
  • If users want to apply a differential operator to function-typed values that are not directly supported (e.g. function values with inout parameters or nondiff parameters), they can form a small "functional" wrapper closure and apply the differential operator to it.

Is there really a use case for directly applying a differential operator to an inout-parameter-function? I think differential operators are designed to be called on "top-level" functionally-typed functions, and the differentiation transform does the work of supporting differentiation through all the guts (inout parameters and nondiff parameters).

@dan-zheng
Contributor

> • If users want to apply a differential operator to function-typed values that are not directly supported (e.g. function values with inout parameters or nondiff parameters), they can form a small "functional" wrapper closure and apply the differential operator to it.

Some examples below. Sorry if the code doesn't actually compile.


nondiff (non-wrt) parameter example:

func foo(_ x: Float, _ weight: Float, _ bias: Float, useBias: Bool) -> Float {
  if useBias {
    return x * weight + bias
  }
  return x * weight
}

// This isn't supported – and that's intended.
gradient(of: foo)

// Do this instead:
gradient(of: { x, weight, bias in foo(x, weight, bias, useBias: true) })
gradient(of: { x, weight, bias in foo(x, weight, bias, useBias: false) })

inout parameter example:

func swap(_ x: inout Float, _ y: inout Float) {
  let tmp = x
  x = y
  y = tmp
}

// This isn't supported – and that's intended, with the current design.
gradient(of: swap)

// This is supported: functionally-typed wrapper closure.
// This is highly contrived.
// Usually `inout` functions are used in the guts of differentiable functions.
// Not directly "differentiated" like this.
gradient(of: { x, y in
  var newX = x
  var newY = y
  swap(&newX, &newY)
  return newX
})
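
Since the snippets above are explicitly untested, here is a hedged sketch of the non-wrt case that should compile against `_Differentiation`. It assumes the three-argument `gradient(at:_:_:of:)` overload. The explicit `wrt:` list and the closure type annotations are additions for clarity and type inference, not part of the original comment.

```swift
import _Differentiation

// `useBias` is a Bool, so it cannot be a differentiation parameter;
// the wrt list names only the Float parameters.
@differentiable(reverse, wrt: (x, weight, bias))
func foo(_ x: Float, _ weight: Float, _ bias: Float, useBias: Bool) -> Float {
  if useBias {
    return x * weight + bias
  }
  return x * weight
}

// Functional wrappers: each closure fixes `useBias`, so every remaining
// parameter is a differentiable one.
let withBias = gradient(at: 2, 3, 4) { (x: Float, weight: Float, bias: Float) -> Float in
  foo(x, weight, bias, useBias: true)
}
let withoutBias = gradient(at: 2, 3, 4) { (x: Float, weight: Float, bias: Float) -> Float in
  foo(x, weight, bias, useBias: false)
}
print(withBias, withoutBias)
```

Each gradient is a tuple of partial derivatives with respect to `x`, `weight`, and `bias`; the `bias` component is zero when `useBias` is false because that branch never reads it.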

@asl
Contributor Author

asl commented Aug 9, 2023

Thanks @dan-zheng, this is very helpful and clears up some of the concerns, though it seems we'd need to figure out what to do in the more "fun" cases.

Consider more elaborated code:

class Class: Differentiable {
  @differentiable(reverse, wrt: (self, x))
  func f(_ x: Float) -> Float {
    // do something weird mutating the internals as well
  }
}

func test<C: Class>(_ c: C, _ x: Float) {
  _ = gradient(at: c, x) { c, x in c.f(x) } // or pullback
}

So, looks like here we'd essentially need to extend the typechecker to disallow non-value-typed arguments in general?

@dan-zheng
Contributor

> Consider more elaborated code:
> So, looks like here we'd essentially need to extend the typechecker to disallow non-value-typed arguments in general?

My understanding of your question:

  • Your example and question above are about "class differentiation" (differentiation wrt class-typed values).
  • "Class differentiation" is interesting because class mutation does not require inout annotations, functions can mutate class arguments even if they are not marked as inout.

If my understanding is correct:

  • I think "class differentiation" and "the differential operator suite" are orthogonal topics that can be discussed separately.
  • If you wanted to nail down useful semantics for class differentiation (e.g. whether to infer class-typed arguments as always implicitly being inout), perhaps you could open a separate issue where we could discuss.

@asl
Contributor Author

asl commented Aug 9, 2023

Well, my point is that they are not orthogonal if we look at differential operators from another angle. Essentially, the operators expect that for a function (T) -> R the signature of e.g. the pullback is (R.Tan) -> T.Tan. And this might not be the case for one reason or another.

> If you wanted to nail down useful semantics for class differentiation (e.g. whether to infer class-typed arguments as always implicitly being inout), perhaps you could open a separate issue where we could discuss.

I think we already have plenty of those, e.g. #55542 and lots of fixmes in the tests / code :)

@dan-zheng
Contributor

> Well, my point is that they are not orthogonal if we look at differential operators from another angle.

I see, that's a fair point.

> So, looks like here we'd essentially need to extend the typechecker to disallow non-value-typed arguments in general?

My initial reaction: yes, further constraining differential operator types (to disallow non-sensible cases like class-typed arguments) makes sense to me. Other folks might also have different ideas.

+1 from me for being use-case-driven in the differential operator design and working to support real non-toy use cases (I'm not sure what these are for class differentiation).

@asl
Contributor Author

asl commented Aug 9, 2023

> +1 from me for being use-case-driven in the differential operator design and working to support real non-toy use cases (I'm not sure what these are for class differentiation).

My main use case here is quite simple, and it would probably cover 80% of sane use cases: Array. Maybe Dictionary as well, but with a much smaller footprint.

@dan-zheng
Contributor

I'm not sure I understood your comment above, since Array and Dictionary are value-semantics structs, not classes.

How about constraining differential operators to disallow functions with class arguments (i.e. AnyObject typed arguments)? Differential operators would continue to work with Array and Dictionary since they're not classes.
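
As an illustration of the value-semantics case (a sketch, not from the thread), the existing operators can be applied to a function of `[Float]`, since `Array` conforms to `Differentiable` when its elements do. The function name `sumFirstThree` is hypothetical.

```swift
import _Differentiation

// Array is a struct with value semantics, so it fits the functional
// (T) -> R shape that the differential operators require.
@differentiable(reverse)
func sumFirstThree(_ xs: [Float]) -> Float {
  xs[0] + xs[1] + xs[2]
}

// The result is an Array<Float>.TangentVector (a DifferentiableView);
// `.base` exposes the underlying [Float] of per-element derivatives.
let grad = gradient(at: [2, 3, 4, 5], of: sumFirstThree)
print(grad.base)
```

Elements that do not contribute to the result get a zero derivative.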

@asl
Contributor Author

asl commented Aug 9, 2023

> I'm not sure I understood your comment above, since Array and Dictionary are value-semantics structs, not classes.
>
> How about constraining differential operators to disallow functions with class arguments (i.e. AnyObject typed arguments)? Differential operators would continue to work with Array and Dictionary since they're not classes.

Yes, this is something I also thought about. Will try experimenting with it. Thanks!

@asl
Contributor Author

asl commented May 29, 2024

@JaapWijnen Found this issue with some additional things
