Refine the implicit wrt logic for @differentiable attribute #67174

Closed
asl opened this issue Jul 6, 2023 · 4 comments · Fixed by #67230

asl commented Jul 6, 2023

Originally posted by @rxwei in #66873 (comment)

I think there are two questions we need to decide:

  • Whether non-wrt parameters should be treated as semantic results.
    • It seems that our conclusion was "no".
  • The default wrt inference behavior, i.e. the behavior when you apply @differentiable without specifying wrt:.
    • My current preference is to never infer wrt-ness on any inout argument. If the user explicitly specifies wrt: on an inout argument, we treat it as a semantic result.
    • Requiring explicit wrt annotations is clearer. We can always add inference rules in the future if there's a usability benefit, without breaking source code.
@differentiable // inferred as wrt: x, semantic result: formal result (no y)
func foo(x: Float, y: inout Float) -> Float

@differentiable // error: 'String' does not conform to 'Differentiable'
func foo(x: Float, y: inout Float) -> String

@differentiable // error: function does not have a differentiable return type
                // note: did you mean to differentiate wrt the inout parameter `y`?
                // fixit: insert `(wrt: y)`
func foo(x: Float, y: inout Float)
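
The proposed inference rule can be sketched as a small Swift model. This is a hypothetical illustration, not the compiler's implementation: `Param`, `Outcome`, `inferDefaultWrt`, and the fixed list of differentiable type names are assumptions made for the example. Applied to the three `foo` declarations, it infers `wrt: x` with the formal result as the only semantic result, rejects the `String` result, and, because inout parameters are never inferred, reports an error with a `(wrt: y)` fix-it for the Void-returning version.

```swift
// Hypothetical toy model of the proposed default `wrt:` inference for a bare
// `@differentiable`; not the Swift compiler's actual implementation or API.

// Assumption: a fixed set of type names stands in for `Differentiable` conformance.
let differentiableTypes: Set<String> = ["Float", "Double"]

struct Param {
    let name: String
    let type: String
    let isInout: Bool
}

enum Outcome: Equatable {
    /// Inferred wrt parameters and the resulting semantic results.
    case inferred(wrt: [String], semanticResults: [String])
    /// Diagnostic message plus an optional fix-it suggestion.
    case error(String, fixit: String?)
}

/// Models `@differentiable` with no `wrt:`. `resultType == nil` means Void.
func inferDefaultWrt(params: [Param], resultType: String?) -> Outcome {
    // Rule: infer every differentiable parameter, but never an inout one.
    let wrt = params
        .filter { !$0.isInout && differentiableTypes.contains($0.type) }
        .map(\.name)

    guard let result = resultType else {
        // No formal result, and inouts are never inferred: nothing to
        // differentiate. Suggest an explicit `wrt:` on a differentiable inout.
        let candidate = params.first(where: { $0.isInout && differentiableTypes.contains($0.type) })
        return .error("function does not have a differentiable return type",
                      fixit: candidate.map { "(wrt: \($0.name))" })
    }
    guard differentiableTypes.contains(result) else {
        return .error("'\(result)' does not conform to 'Differentiable'", fixit: nil)
    }
    // The formal result is the only semantic result; non-wrt inouts are not.
    return .inferred(wrt: wrt, semanticResults: ["result"])
}

// The three `foo` declarations from the comment.
let x = Param(name: "x", type: "Float", isInout: false)
let y = Param(name: "y", type: "Float", isInout: true)
print(inferDefaultWrt(params: [x, y], resultType: "Float"))
print(inferDefaultWrt(params: [x, y], resultType: "String"))
print(inferDefaultWrt(params: [x, y], resultType: nil))
```

Keeping inouts out of the inferred set is what makes the third case an error rather than a silent choice; the user has to opt in with an explicit `wrt:`.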

asl added the AutoDiff label Jul 6, 2023
asl commented Jul 6, 2023

rxwei commented Jul 6, 2023

To clarify what I was arguing for: I think inference for simple cases (no inout parameters) is pretty good, but it gets complicated with multiple semantic results.

asl commented Jul 11, 2023

@rxwei Does this mean that we would always require setters to be differentiated wrt self? Would things like the following, from SILGen/witness_table.swift, then be disallowed?

protocol Protocol: Differentiable {
  @differentiable(reverse, wrt: (self, x, y))
  @differentiable(reverse, wrt: x)
  func method(_ x: Float, _ y: Double) -> Float

  @differentiable(reverse)
  var property: Float { get set }

  @differentiable(reverse, wrt: x)
  subscript(_ x: Float, _ y: Float) -> Float { get set }
}

asl commented Jul 11, 2023

So, we have special handling for inouts in many places. In preparation for future changes (treating class references as potential semantic results, etc.), I did a refactoring that introduces "semantic result parameters". I believe I went through all the places where inouts were handled, checked whether the logic is specific to inouts or applies to semantic result parameters in general, and made the corresponding changes (though I could certainly have missed something obscure or some corner cases).

On top of that, 918e243 implements the logic we discussed: semantic result parameters (inouts for now) are not considered semantic results unless they are wrt parameters.

Please take a look at #67230.
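
The rule implemented in 918e243 can be sketched under a simplified parameter model. The names `ParamKind`, `isSemanticResultParameter`, and `semanticResults` are hypothetical and are not taken from #67230; the sketch only shows how semantic results follow from the wrt set: the formal result, plus any semantic result parameter (inout, for now) that is also a wrt parameter.

```swift
// Hypothetical sketch of the "semantic result parameter" rule; the names are
// illustrative and are not taken from the compiler sources.

enum ParamKind {
    case byValue
    case inOut
    // A future kind (e.g. class references) could be added here and opt in
    // through `isSemanticResultParameter` without changing `semanticResults`.

    /// For now, only inout parameters are semantic result parameters.
    var isSemanticResultParameter: Bool { self == .inOut }
}

struct Param {
    let name: String
    let kind: ParamKind
}

/// Semantic results are the formal result (if it is differentiable) plus
/// every semantic result parameter that is also a wrt parameter.
func semanticResults(params: [Param], wrt: Set<String>,
                     hasDifferentiableResult: Bool) -> [String] {
    var results: [String] = hasDifferentiableResult ? ["result"] : []
    for p in params where p.kind.isSemanticResultParameter && wrt.contains(p.name) {
        results.append(p.name)
    }
    return results
}

// func foo(x: Float, y: inout Float) -> Float
let fooParams = [Param(name: "x", kind: .byValue), Param(name: "y", kind: .inOut)]
print(semanticResults(params: fooParams, wrt: ["x"], hasDifferentiableResult: true))      // ["result"]
print(semanticResults(params: fooParams, wrt: ["x", "y"], hasDifferentiableResult: true)) // ["result", "y"]
```

Because the sketch asks `isSemanticResultParameter` instead of testing for inout directly, a later parameter kind could become a semantic result parameter by changing that one property.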

asl closed this as completed in #67230 Aug 3, 2023
asl added a commit that referenced this issue Aug 3, 2023
Introduce the notion of "semantic result parameter". Handle differentiation of inouts via semantic result parameter abstraction. Do not consider non-wrt semantic result parameters as semantic results

Fixes #67174