-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Fix differentiation for non-wrt inout
parameters.
#33304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
lib/SIL/IR/SILFunctionType.cpp
Outdated
// FIXME(TF-1305): The `getResults().empty()` condition is a hack. | ||
// | ||
// Currently, an `inout` parameter can either be: | ||
// 1. Both a differentiability parameter and a differentiability result. | ||
// 2. `@noDerivative`: neither a differentiability parameter nor a | ||
// differentiability result. | ||
// However, there is no way to represent an `inout` parameter that: | ||
// 3. Is a differentiability parameter but not a differentiability result. | ||
// 4. Is a differentiability result but not a differentiability parameter. | ||
// | ||
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout` | ||
// parameters are not treated as differentiability results, unless the | ||
// original function has no formal results, which case all `inout` | ||
// parameters are treated as differentiability results. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick paste of solution ideas from TF-1305:
- Add new SILParameterDifferentiability kinds, e.g.
InoutParameterNotDifferentiableParameter
(case 3) andInoutParameterNotDifferentiableResult
(case 4).- Consider whether this needs to be exposed in AST
@differentiable
function types.
- Consider whether this needs to be exposed in AST
- Find some way to store parameter/result differentiability in
SILFunctionType
instead of individualSILParameterInfo
/SILResultInfo
.
Fix SIL differential function type calculation to handle non-wrt `inout` parameters. Patch `SILFunctionType::getDifferentiabilityResultIndices` to prevent returning empty result indices for `@differentiable` function types with no formal results where all `inout` parameters are `@noDerivative`. TF-1305 tracks a robust fix. Resolves SR-13305. Exposes TF-1305: parameter/result differentiability hole for `inout` parameters.
@swift-ci Please test |
lib/SIL/IR/SILFunctionType.cpp
Outdated
// 4. Is a differentiability parameter but not a differentiability result. | ||
// This case does not yet have clear use cases, so supporting it is a | ||
// non-goal. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxwei mentioned a nuanced point, which is that there are no clear use cases for inout
parameters that are differentiability parameters but not differentiability results (case 4).
That case is indeed not currently expressible, so we can deprioritize support for it: I clarified that in the comment. Supporting that case would involve something like adding a results:
clause to @differentiable
declaration attribute:
// Pseudo-syntax for supporting case 4:
// `inout` parameter that is a differentiability parameter but not a differentiability result.
@differentiable(wrt: inoutParam, results: return)
func foo(_ inoutParam: inout Float) -> Float {
return inoutParam
}
The focus of this PR is supporting inout
parameters that are differentiability results but that are not differentiability parameters (case 3), which is currently expressible and reasonable to support.
// Case 3: `inout` parameter that is a differentiability result but not a differentiability parameter.
@differentiable(wrt: x)
func foo(_ x: Float, _ inoutParam: inout Float) {
inoutParam = x * x
}
… but not diff. results. There are no clear use cases for `inout` parameters that are differentiability parameters but not differentiability results, so we can de-prioritize support for it. Clarify this in the comment regarding TF-1305.
@swift-ci Please test |
Build failed |
Build failed |
Fix SIL differential function type calculation to handle non-wrt
inout
parameters.
Patch
SILFunctionType::getDifferentiabilityResultIndices
to prevent returningempty result indices for
@differentiable
function types with no formal resultswhere all
inout
parameters are@noDerivative
. TF-1305 tracks a robust fix.Resolves SR-13305.
Exposes TF-1305: parameter/result differentiability hole for
inout
parameters.