[AutoDiff] Supporting differentiable functions with multiple semantic results #66873
@swift-ci please test
@swift-ci please test macos platform
@rxwei @BradLarson @dan-zheng Let me know if there is something that needs to be added / changed :)
Makes sense. I have just a few questions:
- Does this mean that we lose support for inout parameters that are not intended to be treated as a semantic result? Does this break any existing code with inout parameters?
- In SIL we have "result indices" to choose which results to differentiate. How are we sorting the inout results with other results? Could you add some SIL FileCheck tests?
- Should result indices be surfaced to the language so that we can express it in the `@differentiable` attribute?
I do not think so. @BradLarson do you have some production-code example in mind when this might matter?
Currently inout handling is sprinkled around the codebase as a special case. I am working on class differentiation on top of this PR (with the intention of supporting Arrays without manual hacks). For this I needed to generalize the meaning of "semantic result parameter", so that I can treat inouts and class references in a similar way. I think we can postpone adding result bits to the language until we have some common support for these things. So far the semantic result indices will always go after the usual result indices in the declaration order (so they are essentially |
@asl - Last time around, I'd run these changes against our entire codebase and didn't see any cases where it caused a problem, even in speculative differentiable code that would then be able to take advantage of multiple results. I know we have at least one test case with multiple inout parameters and a Regarding result indices, I don't think we were able to come up with a scenario where we needed to |
Not being able to express result selection in

Module1:

    @differentiable
    public func foo(x: Float, y: inout Foo) -> Float
    // `@differentiable` implies the following derivative functions in Module1's ABI.
    public func foo_jvp(x: Float, y: inout Foo) -> (result: Float, differential: (Float) -> Float)
    public func foo_vjp(x: Float, y: inout Foo) -> (result: Float, pullback: (Float) -> Float)

Module2:

    public extension Foo: Differentiable {}

Module3:

    import Module1
    import Module2
    valueWithPullback(at: ...) {
        foo(...)
    }

Because Module1's swiftinterface says If that's indeed what's happening here, I think we should mitigate that before merging this PR. One possible way to deal with this, before we have any user-facing syntax for expressing result selection in |
@rxwei Thanks for the example. Let me investigate.
@rxwei Actually the issue is more obvious. In your example we will be unable to infer the tangent type for |
Do you mean you are unconditionally treating all inout parameters as semantic results? This would be a source-breaking change. Existing code like the following would be broken and will be impossible to express after this PR.

    @differentiable(wrt: input)
    func prediction(from input: Vector, in context: inout Context) -> Vector

After thinking about this a bit more, I don't believe it should be the default behavior. It seems quite confusing to the user to treat one or more inout parameters as "results" when a function also has formal results, let alone the source breakage.
I believe such confusion would be resolved by the above suggestion, especially in cases where the function has a non- |
@rxwei Yes. With this PR we assert on the following code (instead of emitting valid diagnostics):

    public struct ArrayWrapper {
        var values: [Float]
        mutating func get(index: Int) -> Float {
            self.values[index]
        }
    }
    @differentiable(reverse)
    func test(x: Int, y: inout ArrayWrapper, z: Float) {
        y.get(index: x) + z
    }

And we assert instead of compiling:

    public struct ArrayWrapper {
        var values: [Float]
        mutating func get(index: Int) -> Float {
            self.values[index]
        }
    }
    @differentiable(reverse)
    func test(x: Int, y: inout ArrayWrapper, z: Float) -> Float {
        y.get(index: x) + z
    }

So yes, we need to take into account only |
In the case where a function has both inout parameters and formal results, it may be better to treat an inout parameter as a semantic result only when the user explicitly declares |
@rxwei Ok, I think here is the important point. Currently we do support derivatives of void functions with non-wrt inout (#33304), essentially:

    protocol Proto {
        @differentiable(reverse, wrt: x)
        func method(x: Float, y: inout Float)
    }
    struct Struct: Proto {
        @differentiable(reverse, wrt: x)
        func method(x: Float, y: inout Float) {
            y = y * x
        }
    }

So, are you suggesting the following:
Is this correct?
I think there are two angles where we need to decide:

    @differentiable // inferred as wrt: x, semantic result: formal result (no y)
    func foo(x: Float, y: inout Float) -> Float

    @differentiable // error: 'String' does not conform to 'Differentiable'
    func foo(x: Float, y: inout Float) -> String

    @differentiable // error: function does not have a differentiable return type
                    // note: did you mean to differentiate wrt the inout parameter `y`?
                    // fixit: insert `(wrt: y)`
    func foo(x: Float, y: inout Float)
I am thinking we'd need to make different rules for ordinary functions vs methods. In the latter case it might make sense to differentiate around |
@swift-ci please test
@rxwei @BradLarson Ok, while we are deciding on the best way to handle inout semantic results, here is the commit (1ab1598) that I believe should fix the immediate problems while keeping the existing logic intact (plus things are factored in such a way that it would be easy to change inout handling later on).
I think now we are correctly handling result indices in all cases (e.g. for methods where PTAL
The change to requiring |
I'm trying to understand this clearly. Is this saying: if a function returns

    @differentiable(wrt: x)
    func foo(x: Float, y: inout Float)

Are you saying that the function above treats |
As I said, the PR keeps the existing behavior for void functions and semantic results: for void functions all inouts are treated as semantic results regardless of whether they are wrt or not. To avoid possible ABI breakage we require all semantic results to be differentiable (rather than silently skipping them, as before). My main concern for now is methods: |
Can we have a follow-up PR to tighten up the wrt inference a little bit? The existing implicit behavior seems too complex.
I don't think there is harm to require explicit |
Absolutely! I think we'd need to prepare a set of cases / code examples and decide on all of them |
I created #67174 to track the refined |
PR #32629 added reverse-mode differentiation support for apply instructions with multiple active semantic results. This completes user-facing support for differentiable functions with multiple semantic results.
Previously, it was not possible to state that a function with multiple semantic results was `@differentiable`. This included:
It is now possible to mark these functions as `@differentiable` and to supply custom pullbacks for them.
This is essentially #38781 rebased on `main` with additional bugfixes and some changes here and there.
Co-authored-by: @BradLarson
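
To make "multiple semantic results" concrete, here is a minimal plain-Swift sketch. It is written by hand and deliberately does not use the `_Differentiation` module, so it is not the signature the compiler generates. The names `scaleAndSum` and `scaleAndSumWithPullback` are hypothetical, introduced only for this illustration. The function has two semantic results (its formal result and the updated `inout` parameter `y`), and the pullback takes one cotangent per result.

```swift
// Hypothetical function with two semantic results:
// the formal result and the updated value of the inout parameter `y`.
func scaleAndSum(_ x: Float, _ y: inout Float) -> Float {
    y *= x          // yNew = yOld * x
    return x + y    // result = x + yNew
}

// Hand-written reverse-mode companion (an assumption for illustration,
// not the derivative function type produced by the compiler).
func scaleAndSumWithPullback(
    _ x: Float, _ y: inout Float
) -> (value: Float, pullback: (_ dResult: Float, _ dY: Float) -> (dx: Float, dy: Float)) {
    let yOld = y
    let value = scaleAndSum(x, &y)
    let pullback: (Float, Float) -> (dx: Float, dy: Float) = { dResult, dY in
        // Total cotangent reaching yNew: via the formal result and via `y` itself.
        let dYNew = dResult + dY
        return (dx: dResult + dYNew * yOld, dy: dYNew * x)
    }
    return (value, pullback)
}

var y: Float = 2
let (value, pullback) = scaleAndSumWithPullback(3, &y)
// Seed only the formal result; the cotangent for the updated `y` is zero.
let gradient = pullback(1, 0)
print(value, y, gradient.dx, gradient.dy)
```

With `x = 3` and `y = 2`, the formal result is 9, `y` becomes 6, and seeding only the formal result gives a gradient of 3 with respect to both `x` and the original `y`. Seeding the second cotangent instead (`pullback(0, 1)`) differentiates the updated `y`, which is the kind of result selection the discussion above refers to as result indices.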