[AutoDiff] Support `inout` argument differentiation. #30013
Note: merging
This is blocking the Lesson learned: next time when doing changes that touch both upstream and downstream code, create PRs to both |
🎉 |
Semantically, an `inout` parameter is both a parameter and a result.

`@differentiable` and `@derivative` attributes now support original functions with one "semantic result": either a formal result or an `inout` parameter.

Derivative typing rules for functions with `inout` parameters are now defined. The differential/pullback type of a function with `inout` differentiability parameters also has `inout` parameters. This is ideal for performance.

Differential typing rules:
- Case 1: original function has no `inout` parameters.
  - Original: `(T0, T1, ...) -> R`
  - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
- Case 2: original function has a non-wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Differential: `(T0.Tan, ...) -> T1.Tan`
- Case 3: original function has a wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`

Pullback typing rules:
- Case 1: original function has no `inout` parameters.
  - Original: `(T0, T1, ...) -> R`
  - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
- Case 2: original function has a non-wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
- Case 3: original function has a wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`

Resolves TF-1164.
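To make the Case 3 rules concrete, here is a minimal hand-written sketch in plain Swift. It does not use `@differentiable` or `@derivative`, and it is not compiler output. The names `scale`, `scaleDifferential`, and `scalePullback` are hypothetical, with `T0 = T1 = Float` so that each tangent type is also `Float`.

```swift
// Hand-written illustration of the Case 3 typing rules; not compiler output.
// Original: (T0, inout T1) -> Void, with T0 = T1 = Float.
func scale(_ factor: Float, _ x: inout Float) {
  x *= factor
}

// Differential: (T0.Tan, inout T1.Tan) -> Void
// d(x * factor) = dx * factor + x * dFactor, written into `dx` in place.
// `x` is the value of the inout argument before the original call.
func scaleDifferential(factor: Float, x: Float) -> (Float, inout Float) -> Void {
  return { (dFactor: Float, dx: inout Float) in
    dx = dx * factor + x * dFactor
  }
}

// Pullback: (inout T1.Tan) -> T0.Tan
// The seed `dx` is updated in place; the adjoint of `factor` is returned.
func scalePullback(factor: Float, x: Float) -> (inout Float) -> Float {
  return { (dx: inout Float) -> Float in
    let dFactor = dx * x  // adjoint of `factor`
    dx *= factor          // adjoint of `x`, stored back into the seed
    return dFactor
  }
}
```

Calling `scalePullback(factor: 3, x: 2)` on a seed of `1` leaves the seed at `3` and returns `2`, matching the partial derivatives of `x * factor` at that point.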
Add reverse-mode differentiation support for `apply` with `inout` arguments.

Notable pullback generation changes:
- If the pullback seed argument is `inout`, assign it (rather than a copy) directly as the adjoint buffer of the original result. This is important so the value is updated in-place.
- In `visitApplyInst`: skip adjoint accumulation for `inout` arguments. Adjoint accumulation for `inout` arguments occurs when callee pullbacks are applied, so no extra accumulation is necessary.

Add derivatives for functions with `inout` parameters in the stdlib for testing:
- `FloatingPoint` operations: `+=`, `-=`, `*=`, `/=`
- `Array.append`

Resolves TF-1165.

Todos:
- Add more tests, e.g. SILGen tests for `inout` derivative typing rules.
- Evaluate performance of `inout` derivatives vs functional derivatives + mutation.
- TF-1166: enable `@differentiable` attribute on `set` accessors.
- TF-1173: add forward-mode differentiation support for `apply` with `inout` parameters.
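The "skip adjoint accumulation" point can be illustrated with a hand-written sketch in plain Swift. It models what a generated pullback might do and is not the actual SIL the compiler emits. All function names are hypothetical, and the callee pullbacks are simplified stand-ins for the stdlib derivatives of `+=` and `*=`.

```swift
// Hand-written model of pullback composition; not compiler output.

// Callee pullback for `x += y`: the adjoint of x is unchanged, dy = dx.
func addAssignPullback() -> (inout Float) -> Float {
  return { (dx: inout Float) -> Float in
    return dx
  }
}

// Callee pullback for `x *= y`, where `x` is the value before the multiply.
func mulAssignPullback(x: Float, y: Float) -> (inout Float) -> Float {
  return { (dx: inout Float) -> Float in
    let dy = dx * x
    dx *= y  // the callee itself updates the inout adjoint in place
    return dy
  }
}

// Original function: two `apply`s with an `inout` argument.
func addThenScale(_ x: inout Float, _ a: Float, _ b: Float) {
  x += a
  x *= b
}

// Pullback of `addThenScale`: (inout Float) -> (Float, Float).
func addThenScalePullback(x: Float, a: Float, b: Float)
  -> (inout Float) -> (Float, Float) {
  let pbAdd = addAssignPullback()
  let pbMul = mulAssignPullback(x: x + a, y: b)
  return { (dx: inout Float) -> (Float, Float) in
    // The inout seed is used directly as the adjoint buffer of `x`.
    // Callee pullbacks run in reverse order and update `dx` themselves,
    // so the caller performs no extra accumulation into `dx`.
    let db = pbMul(&dx)
    let da = pbAdd(&dx)
    return (da, db)
  }
}
```

If the caller also added each callee's contribution into `dx`, the adjoint of `x` would be counted twice, which is consistent with the incorrect derivative values described elsewhere in this conversation.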
Missing test cases:
- Differentiating a function with `inout` parameters that have a `class` type. This is expected to work normally.
- Differentiating a function whose body calls a same-file function that both takes a `Differentiable` `inout` argument and returns a `Differentiable` formal result where both the `inout` argument and the formal result are active. This is expected to trigger a differentiation error.
Fixes `-O` test failure:
```
Failing Tests (1):
    Swift(linux-x86_64) :: AutoDiff/downstream/inout_parameters.swift
```
Class-typed arguments are not always marked active. This produces incorrect derivative results.
Mutating `inout` class argument via `store`: correct derivatives. Mutating `inout` class argument via `modify` accessor: incorrect derivatives. Tracked at TF-1176.
Any forward mode tests?
Forward-mode differentiation doesn't support |
Let's standardize on `unsigned`.
@swift-ci Please test tensorflow |
Change the following:
- `ActionValueCalculator` from a struct to a class.
- `ActionValueCalculator.loss` from a mutating method (necessary for structs) to a non-mutating one (unnecessary for classes).

This fixes a new type-checking error after `inout` argument differentiation typing rules, added in swiftlang/swift#30013:
```
ExploitabilityDescent.swift:218:4: error: cannot differentiate between functions with both an 'inout' parameter and a result
  @differentiable(wrt: policy)
```
PiperOrigin-RevId: 297457176
Change-Id: I3374d45c1d56124e20973a7f262bd10bd6e06053
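A rough sketch of why the struct-to-class change sidesteps the error, using hypothetical stand-in types (`StructCalculator`, `ClassCalculator`) rather than the real `ActionValueCalculator`, whose members are not shown here. The differentiability attributes are omitted so the snippet compiles as plain Swift. The point is only the shape of each method's type.

```swift
// Hypothetical stand-ins; the real ActionValueCalculator is not shown here.

struct StructCalculator {
  var cache: Float = 0

  // A mutating method takes `self` as an implicit `inout` parameter, so this
  // has both an `inout` parameter (self) and a formal result (Float).
  mutating func loss(_ policy: Float) -> Float {
    cache = policy * policy
    return cache
  }
}

final class ClassCalculator {
  var cache: Float = 0

  // For a class, `self` is a reference passed by value, so the method can
  // mutate stored properties without `mutating`. The formal result is the
  // only semantic result.
  func loss(_ policy: Float) -> Float {
    cache = policy * policy
    return cache
  }
}
```

Under the typing rules in this PR, only the second shape has a single semantic result.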
Includes cherry-pick of #29959: typing rules for `inout` parameter differentiation.

Add reverse-mode differentiation support for `apply` with `inout` arguments.

Notable pullback generation changes:
- If the pullback seed argument is `inout`, assign it (rather than a copy) directly as the adjoint buffer of the original result. This is important so the value is updated in-place.
- In `visitApplyInst`: skip adjoint accumulation for `inout` arguments. Adjoint accumulation for `inout` arguments occurs when callee pullbacks are applied, so no extra accumulation is necessary.

Add derivatives for functions with `inout` parameters in the stdlib for testing:
- `FloatingPoint` operations: `+=`, `-=`, `*=`, `/=`
- `Array.append`

Resolves TF-1165.

Todos:
- Add more tests, e.g. SILGen tests for `inout` derivative typing rules.
- Evaluate performance of `inout` derivatives vs functional derivatives + mutation.
- TF-1166: enable `@differentiable` attribute on `set` accessors.
- TF-1173: add forward-mode differentiation support for `apply` with `inout` parameters.
- Exposes TF-1175: incorrect activity for class arguments.
- Exposes TF-1176: incorrect activity for class `modify` accessors.
- Add negative tests.
Examples:
Thanks @marcrasi for peer debugging incorrect `inout` argument derivative values! We discovered that `Pullback::visitApplyInst` should skip adjoint accumulation for `inout` arguments, which is interesting and not obvious. This made the first exciting test cases pass.

Working together was fun and productive 🙂