[AutoDiff] Supporting differentiable functions with multiple semantic results #66873

Merged
4 commits merged on Jul 6, 2023

Conversation

asl
Contributor

@asl asl commented Jun 22, 2023

PR #32629 added reverse-mode differentiation support for apply instructions with multiple active semantic results. This completes user-facing support for differentiable functions with multiple semantic results.

Previously, it was not possible to state that a function with multiple semantic results was @differentiable. This included:

  • functions with an inout parameter which returned a result
  • functions with multiple inout parameters
  • mutating functions which returned a result
  • functions that return a tuple of results

It is now possible to mark these functions as @differentiable and to supply custom pullbacks for them.
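To make "multiple semantic results" concrete, here is a hand-written sketch in plain Swift. It does not use the `_Differentiation` module and is not the compiler's generated code or ABI; the names `scaleAndSum` and `scaleAndSumWithPullback` are hypothetical. The function has two semantic results (the returned value and the final value of the inout parameter), so the pullback takes one cotangent per result.

```swift
// Original function: mutates `x` and also returns a value.
func scaleAndSum(_ x: inout Float, _ y: Float) -> Float {
    x = x * y
    return x + y
}

// Hypothetical manual VJP. The pullback maps the cotangents of both
// semantic results (returned value, final `x`) to the cotangents of
// the inputs (initial `x`, `y`).
func scaleAndSumWithPullback(_ x: inout Float, _ y: Float)
    -> (value: Float, pullback: (Float, Float) -> (Float, Float)) {
    let xIn = x
    let value = scaleAndSum(&x, y)
    let pullback: (Float, Float) -> (Float, Float) = { dValue, dXOut in
        // value = xOut + y, so dValue also flows into xOut.
        let dXOutTotal = dXOut + dValue
        // xOut = xIn * y
        let dXIn = dXOutTotal * y
        let dY = dXOutTotal * xIn + dValue
        return (dXIn, dY)
    }
    return (value, pullback)
}

var x: Float = 3
let (value, pullback) = scaleAndSumWithPullback(&x, 2)
// Seed only the returned value; the final `x` gets a zero cotangent.
let (dX, dY) = pullback(1, 0)
print(x, value, dX, dY)
```

Seeding `pullback(0, 1)` instead would give the gradient of the mutated `x` alone, which is why both results need their own cotangent slot.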

This is essentially #38781 rebased on main, with additional bug fixes and a few other changes.

Co-authored-by: @BradLarson

@asl
Contributor Author

asl commented Jun 22, 2023

@swift-ci please test

@asl
Contributor Author

asl commented Jun 23, 2023

@swift-ci please test macos platform

@asl
Contributor Author

asl commented Jun 23, 2023

@rxwei @BradLarson @dan-zheng Let me know if there is something that needs to be added / changed :)

Contributor

@rxwei rxwei left a comment

Makes sense. I have just a few questions:

  • Does this mean that we lose support for inout parameters that are not intended to be treated as a semantic result? Does this break any existing code with inout parameters?
  • In SIL we have "result indices" to choose which results to differentiate. How are we sorting the inout results with other results? Could you add some SIL FileCheck tests?
  • Should result indices be surfaced to the language so that we can express it in the @differentiable attribute?

@asl
Contributor Author

asl commented Jun 23, 2023

  • Does this mean that we lose support for inout parameters that are not intended to be treated as a semantic result? Does this break any existing code with inout parameters?

I do not think so. @BradLarson do you have a production-code example in mind where this might matter?

  • In SIL we have "result indices" to choose which results to differentiate. How are we sorting the inout results with other results? Could you add some SIL FileCheck tests?
  • Should result indices be surfaced to the language so that we can express it in the @differentiable attribute?

Currently, inout handling is sprinkled around the codebase as a special case. I am working on class differentiation on top of this PR (with the intention of supporting Arrays without manual hacks). For this I needed to generalize the meaning of "semantic result parameter" so that I can treat inouts and class references in a similar way. I think we can postpone adding result bits to the language until we have some common support for these things.

So far, the semantic result indices always go after the usual result indices, in declaration order (so they are essentially numResults + original parameter index). I think we will need to make this user-friendly when we expose them to users :)
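Read literally, the numbering described above can be sketched as follows. `semanticResultIndex` is a hypothetical helper written for illustration, not a compiler API, and the compiler's internal bookkeeping may differ in detail.

```swift
// Hypothetical helper: the index of an inout semantic result, taking the
// "numResults + original parameter index" description at face value.
func semanticResultIndex(formalResultCount: Int, parameterIndex: Int) -> Int {
    return formalResultCount + parameterIndex
}

// For a signature shaped like:
//   func foo(x: Float, y: inout Float) -> Float
// the formal result keeps index 0, and `y` (parameter index 1) comes after it.
let formalResultIndex = 0
let inoutYIndex = semanticResultIndex(formalResultCount: 1, parameterIndex: 1)
print(formalResultIndex, inoutYIndex)
```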

@BradLarson
Contributor

@asl - Last time around, I'd run these changes against our entire codebase and didn't see any cases where it caused a problem, even in speculative differentiable code that would then be able to take advantage of multiple results. I know we have at least one test case with multiple inout parameters and a wrt to ignore one of them for purposes of differentiability as an input. That case was handled correctly in terms of the arguments, but I don't recall if the ignored inout parameter was still present as an active result.

Regarding result indices, I don't think we were able to come up with a scenario where we needed to wrt results. There are common cases for wrt on arguments, but we couldn't determine when we'd need to do so for results.

@rxwei
Contributor

rxwei commented Jun 24, 2023

Not being able to express result selection in the @differentiable attribute can cause pretty bad ABI mismatches. IIRC we currently infer result indices from types conforming to Differentiable. Let's consider the following scenario:

Module1:

@differentiable
public func foo(x: Float, y: inout Foo) -> Float

// `@differentiable` implies the following derivative functions in Module1's ABI.
public func foo_jvp(x: Float, y: inout Foo) -> (result: Float, differential: (Float) -> Float)
public func foo_vjp(x: Float, y: inout Foo) -> (result: Float, pullback: (Float) -> Float)

Module2:

extension Foo: Differentiable {}

Module3:

import Module1
import Module2

valueWithPullback(at: ...) {
    foo(...)
}

Because Module1's swiftinterface says @differentiable(wrt: x) without specifying the result indices, the result indices are always going to be inferred when the swiftinterface is type-checked. Module2 made Foo conform to Differentiable, so the compiler is now considering the inout Foo parameter as a differentiable semantic result, and will infer wrong result indices for declaring a differentiability witness (and getting the wrong JVP and VJP), hence undefined symbols.

If that's indeed what's happening here, I think we should mitigate that before merging this PR. One possible way to deal with this, before we have any user-facing syntax for expressing result selection in @differentiable, is to never treat an inout parameter as a semantic result unless it's a "wrt" parameter, because "wrt" guarantees that the inout parameter conforms to Differentiable when it's being differentiated, both when it's a wrt parameter and when it's a semantic result.

@asl
Contributor Author

asl commented Jun 24, 2023

@rxwei Thanks for the example. Let me investigate.

@asl
Contributor Author

asl commented Jun 25, 2023

@rxwei Actually the issue is more obvious. In your example we will be unable to infer the tangent type for Foo in Module1.

@rxwei
Contributor

rxwei commented Jun 25, 2023

@rxwei Actually the issue is more obvious. In your example we will be unable to infer the tangent type for Foo in Module1.

Do you mean you are unconditionally treating all inout parameters as semantic results? This would be a source-breaking change. Existing code like the following would break and would be impossible to express after this PR.

@differentiable(wrt: input)
func prediction(from input: Vector, in context: inout Context) -> Vector

After thinking about this a bit more, I don't believe it should be the default behavior. It seems quite confusing to the user to treat one or more inout parameters as "results" when a function also has formal results, let alone the source breakage.

One possible way to deal with this, before we have any user-facing syntax for expressing result selection in @differentiable, is to never treat an inout parameter as a semantic result unless it's a "wrt" parameter

I believe such confusion would be resolved by the above suggestion, especially in cases where the function has a non-Void return type.

@asl
Contributor Author

asl commented Jun 25, 2023

@rxwei Yes. With this PR we hit an assertion on the following code (instead of emitting a valid diagnostic):

public struct ArrayWrapper {
    var values: [Float]

    mutating func get(index: Int) -> Float {
        self.values[index]
    }
}

@differentiable(reverse)
func test(x: Int, y: inout ArrayWrapper, z: Float) {
    y.get(index: x) + z
}

And we hit an assertion instead of compiling this:

public struct ArrayWrapper {
    var values: [Float]

    mutating func get(index: Int) -> Float {
        self.values[index]
    }
}

@differentiable(reverse)
func test(x: Int, y: inout ArrayWrapper, z: Float) -> Float {
    y.get(index: x) + z
}

So yes, we need to take into account only wrt semantic results.

@rxwei
Contributor

rxwei commented Jun 25, 2023

In the case where a function has both inout parameters and formal results, it may be better to treat an inout parameter as a semantic result only when the user explicitly declares wrt: for that parameter (even if the type conforms to Differentiable.)

@asl
Contributor Author

asl commented Jun 27, 2023

@rxwei Ok, I think here is the important point. Currently we do support derivatives of void functions with non-wrt inout (#33304), essentially:

protocol Proto {
  @differentiable(reverse, wrt: x)
  func method(x: Float, y: inout Float)
}

struct Struct: Proto {
  @differentiable(reverse, wrt: x)
  func method(x: Float, y: inout Float) {
    y = y * x
  }
}

So, are you suggesting the following:

  1. For void functions, treat all inouts as semantic results (and all of them should be differentiable)
  2. For non-void functions, treat only wrt inouts as semantic results

Is this correct?
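For the Void-returning `method(x:y:)` shape above, the only place a derivative can flow out of is the inout `y`, even though differentiation is with respect to `x` alone. A hand-written sketch in plain Swift follows (no `_Differentiation` import; `Scaler` and `methodWithPullback` are hypothetical names, and this is not what the compiler emits):

```swift
struct Scaler {
    // Same body as the method in the example: the only output is the inout `y`.
    func method(x: Float, y: inout Float) {
        y = y * x
    }

    // Hypothetical manual VJP with respect to `x` only. The pullback takes
    // the cotangent of the final `y` and returns the cotangent of `x`.
    func methodWithPullback(x: Float, y: inout Float) -> (Float) -> Float {
        let yIn = y
        method(x: x, y: &y)
        // yOut = yIn * x, so d(yOut)/dx = yIn.
        return { dYOut in dYOut * yIn }
    }
}

var accumulator: Float = 4
let methodPullback = Scaler().methodWithPullback(x: 3, y: &accumulator)
print(accumulator, methodPullback(1))
```

If `y` were not treated as a semantic result here, the function would have nothing left to differentiate, which is the motivation for keeping the existing behavior for Void functions.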

@rxwei
Contributor

rxwei commented Jun 28, 2023

I think there are two angles where we need to decide:

  • Whether non-wrt parameters should be treated as semantic results.
    • It seems that our conclusion was "no".
  • The default wrt inference behavior, i.e. the behavior when you apply @differentiable without specifying wrt:.
    • My current preference is to never infer wrt-ness on any inout argument. If the user explicitly specifies wrt: on an inout argument, we treat it as a semantic result.
    • Enforcing explicit wrt annotations has better clarity. We can always add inference rules in the future if there's a usability benefit, without breaking source code.

@differentiable // inferred as wrt: x, semantic result: formal result (no y)
func foo(x: Float, y: inout Float) -> Float

@differentiable // error: 'String' does not conform to 'Differentiable'
func foo(x: Float, y: inout Float) -> String

@differentiable // error: function does not have a differentiable return type
                // note: did you mean to differentiate wrt the inout parameter `y`?
                // fixit: insert `(wrt: y)`
func foo(x: Float, y: inout Float)

@asl
Contributor Author

asl commented Jun 28, 2023

  • My current preference is to never infer wrt-ness on any inout argument. If the user explicitly specifies wrt: on an inout argument, we treat it as a semantic result.

I am thinking we'd need different rules for ordinary functions vs. methods. In the latter case it might make sense to differentiate with respect to inout Self for mutating methods. Otherwise things might be confusing, as self is implicit there.

@asl
Contributor Author

asl commented Jun 28, 2023

@swift-ci please test

@asl
Contributor Author

asl commented Jun 28, 2023

@rxwei @BradLarson Ok, while we are deciding on the best way to handle inout semantic results, here is the commit (1ab1598) that I believe should fix the immediate problems while keeping the existing logic intact (plus things are factored in such a way that it would be easy to change inout handling later on).

  1. wrt inouts are treated as semantic results
  2. Unless the function returns Void, non-wrt inouts are not considered semantic results
  3. All inouts are considered as semantic results for void functions
  4. All semantic results are required to be differentiable, we do not silently skip them
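The four rules above can be modeled in a few lines of plain Swift. This is a simplified illustration only: `Parameter`, `SemanticResultError`, and `inoutSemanticResults` are hypothetical names, not compiler data structures, and the model ignores formal results and implicit `self`.

```swift
// Simplified, hypothetical description of one function parameter.
struct Parameter {
    let name: String
    let isInout: Bool
    let isWrt: Bool
    let isDifferentiable: Bool
}

enum SemanticResultError: Error, Equatable {
    case notDifferentiable(parameter: String)
}

// Returns the names of the inout parameters treated as semantic results.
func inoutSemanticResults(
    returnsVoid: Bool,
    parameters: [Parameter]
) throws -> [String] {
    var results: [String] = []
    for parameter in parameters where parameter.isInout {
        // Rules 1-3: wrt inouts always count; non-wrt inouts count only
        // when the function returns Void.
        guard parameter.isWrt || returnsVoid else { continue }
        // Rule 4: a semantic result must be differentiable; it is an error
        // rather than something to skip silently.
        guard parameter.isDifferentiable else {
            throw SemanticResultError.notDifferentiable(parameter: parameter.name)
        }
        results.append(parameter.name)
    }
    return results
}

let xParam = Parameter(name: "x", isInout: false, isWrt: true, isDifferentiable: true)
let yParam = Parameter(name: "y", isInout: true, isWrt: false, isDifferentiable: true)

// Shape: func foo(x: Float, y: inout Float) -> Float, wrt: x
let nonVoidResults = try inoutSemanticResults(returnsVoid: false, parameters: [xParam, yParam])
// Shape: func foo(x: Float, y: inout Float), wrt: x
let voidResults = try inoutSemanticResults(returnsVoid: true, parameters: [xParam, yParam])
print(nonVoidResults, voidResults)
```

In this model the same non-wrt `y` is ignored for the non-Void signature and picked up for the Void one, which is the asymmetry questioned later in the thread.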

I think we are now correctly handling result indices in all cases (e.g. for methods where Self is curried).

PTAL

@BradLarson
Contributor

The change to requiring wrt for inout parameters with non-Void returns will lead to a slight change in behavior in the case where there's a single inout parameter and a non-differentiable result, which previously did not require a wrt annotation. I'm fine with annotating the few places in our code that use a function like that, but would that cause a problem for any other existing code out there?

@rxwei
Contributor

rxwei commented Jul 2, 2023

Unless the function returns Void, non-wrt inouts are not considered semantic results

I'm trying to understand this clearly. Is this saying: if a function returns Void, non-wrt inouts are semantic results? What about the following case:

@differentiable(wrt: x)
func foo(x: Float, y: inout Float)

Are you saying that the function above treats y as a semantic result even if it's not wrt? If so, it seems to contradict the earlier conclusion as I understood it.

@asl
Contributor Author

asl commented Jul 2, 2023

Are you saying that the function above treats y as a semantic result even if it's not wrt? If so, it seems to contradict the earlier conclusion as I understood it.

As I said, the PR keeps the existing behavior for void functions and semantic results: for void functions, all inouts are treated as semantic results regardless of whether they are wrt or not. To avoid possible ABI breakage, we require all semantic results to be differentiable (rather than silently skipping them, as before).

My main concern for now is methods: Self is inout for struct mutating methods (and always a result for class methods), but it is implicit for the user. Are we going to require an explicit wrt Self here?

@rxwei
Contributor

rxwei commented Jul 6, 2023

Can we have a follow-up PR to tighten up the wrt inference a little bit? The existing implicit behavior seems too complex.

My main concern for now is methods: Self is inout for struct mutating methods (and always a result for class methods), but it is implicit for the user. Are we going to require an explicit wrt Self here?

I don't think there is any harm in requiring an explicit wrt: self.

@asl
Contributor Author

asl commented Jul 6, 2023

Can we have a follow-up PR to tighten up the wrt inference a little bit? The existing implicit behavior seems too complex.

Absolutely! I think we'd need to prepare a set of cases / code examples and decide on all of them.

@asl
Contributor Author

asl commented Jul 6, 2023

I created #67174 to track the refined wrt semantics
