Skip to content

[AutoDiff] Fix differentiation for non-wrt inout parameters. #33304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2020

Conversation

dan-zheng
Copy link
Contributor

Fix SIL differential function type calculation to handle non-wrt inout
parameters.

Patch SILFunctionType::getDifferentiabilityResultIndices to prevent returning
empty result indices for @differentiable function types with no formal results
where all inout parameters are @noDerivative. TF-1305 tracks a robust fix.

Resolves SR-13305.

Exposes TF-1305: parameter/result differentiability hole for inout parameters.

Comment on lines 238 to 251
// FIXME(TF-1305): The `getResults().empty()` condition is a hack.
//
// Currently, an `inout` parameter can either be:
// 1. Both a differentiability parameter and a differentiability result.
// 2. `@noDerivative`: neither a differentiability parameter nor a
// differentiability result.
// However, there is no way to represent an `inout` parameter that:
// 3. Is a differentiability parameter but not a differentiability result.
// 4. Is a differentiability result but not a differentiability parameter.
//
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout`
// parameters are not treated as differentiability results, unless the
// original function has no formal results, which case all `inout`
// parameters are treated as differentiability results.
Copy link
Contributor Author

@dan-zheng dan-zheng Aug 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick paste of solution ideas from TF-1305:

  • Add new SILParameterDifferentiability kinds, e.g. InoutParameterNotDifferentiableParameter (case 3) and InoutParameterNotDifferentiableResult (case 4).
    • Consider whether this needs to be exposed in AST @differentiable function types.
  • Find some way to store parameter/result differentiability in SILFunctionType instead of individual SILParameterInfo/SILResultInfo.

@dan-zheng dan-zheng requested review from marcrasi and rxwei August 5, 2020 00:40
Fix SIL differential function type calculation to handle non-wrt `inout`
parameters.

Patch `SILFunctionType::getDifferentiabilityResultIndices` to prevent returning
empty result indices for `@differentiable` function types with no formal results
where all `inout` parameters are `@noDerivative`. TF-1305 tracks a robust fix.

Resolves SR-13305.

Exposes TF-1305: parameter/result differentiability hole for `inout` parameters.
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test

Comment on lines 246 to 248
// 4. Is a differentiability parameter but not a differentiability result.
// This case does not yet have clear use cases, so supporting it is a
// non-goal.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei mentioned a nuanced point, which is that there are no clear use cases for inout parameters that are differentiability parameters but not differentiability results (case 4).

That case is indeed not currently expressible, so we can deprioritize support for it: I clarified that in the comment. Supporting that case would involve something like adding a results: clause to @differentiable declaration attribute:

// Pseudo-syntax for supporting case 4:
// `inout` parameter that is a differentiability parameter but not a differentiability result.
@differentiable(wrt: inoutParam, results: return)
func foo(_ inoutParam: inout Float) -> Float {
  return inoutParam
}

The focus of this PR is supporting inout parameters that are differentiability results but that are not differentiability parameters (case 3), which is currently expressible and reasonable to support.

// Case 3: `inout` parameter that is a differentiability result but not a differentiability parameter.
@differentiable(wrt: x)
func foo(_ x: Float, _ inoutParam: inout Float) {
  inoutParam = x * x
}

… but not diff. results.

There are no clear use cases for `inout` parameters that are differentiability
parameters but not differentiability results, so we can de-prioritize support
for it.

Clarify this in the comment regarding TF-1305.
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test

@swift-ci
Copy link
Contributor

swift-ci commented Aug 5, 2020

Build failed
Swift Test Linux Platform
Git Sha - ffe8d78

@swift-ci
Copy link
Contributor

swift-ci commented Aug 5, 2020

Build failed
Swift Test OS X Platform
Git Sha - ffe8d78

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants