Skip to content

[Sema] Differentiable conformance derivation for class types. #25914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

dan-zheng
Copy link
Contributor

@dan-zheng dan-zheng commented Jul 1, 2019

Differentiable derived conformances now supports class types.
Synthesis works just like for struct types.

Class differentiation support requires further differentiation transform changes.

Resolves TF-630.

Next steps:

  • TF-631: Enable differentiation of class methods (handle ref_element_addr in differentiation transform).
  • TF-633: Update tensorflow-swift-apis using Differentiable-conforming class types.

class Example<T : Differentiable> : Differentiable {
  var x, y: T
  @noDerivative var flag: Bool = true

  // Unlike structs, classes do not get synthesized memberwise initializers.
  // User must define initializer or provide default values for stored properties.
  init(x: T, y: T) {
    self.x = x
    self.y = y
  }

  // Compiler synthesizes (omitting `AllDifferentiableVariables` code):
  //
  // struct TangentVector : Differentiable, AdditiveArithmetic {
  //   var x: T.TangentVector
  //   var y: T.TangentVector
  //   init(x: T.TangentVector, y: T.TangentVector)
  //   typealias TangentVector = Example<T>.TangentVector
  //   final static var zero: Example<T>.TangentVector { get }
  //   static func + (lhs: Example<T>.TangentVector, rhs: Example<T>.TangentVector) -> Example<T>.TangentVector
  //   static func - (lhs: Example<T>.TangentVector, rhs: Example<T>.TangentVector) -> Example<T>.TangentVector
  //   @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Example<T>.TangentVector, _ b: Example<T>.TangentVector) -> Bool
  // }
  //
  // func move(along direction: Example<T>.TangentVector) {
  //   x.move(along: direction.x)
  //   y.move(along: direction.y)
  // }
}

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Jul 1, 2019
@dan-zheng dan-zheng requested a review from rxwei July 1, 2019 19:08
`Differentiable` derived conformances now supports class types.
Synthesis works just like for struct types.

Class differentiation support requires further differentiation
transform changes.

Resolves TF-630.
@dan-zheng dan-zheng force-pushed the class-derived-differentiable branch from 939b132 to bf45e85 Compare July 1, 2019 19:19
@@ -0,0 +1,570 @@
// SWIFT_ENABLE_TENSORFLOW
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are not the best, I directly copied test/Sema/struct_differentiable.swift.
Todo: clean up and tighten tests. Fuzzing would be cool.

@dan-zheng
Copy link
Contributor Author

@swift-ci Please test tensorflow

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test that guarantees that TangentVector will never be derived to be Self for classes, even if Self conforms to AdditiveArithmetic?

@dan-zheng
Copy link
Contributor Author

dan-zheng commented Jul 1, 2019

Could you add a test that guarantees that TangentVector will never be derived to be Self for classes, even if Self conforms to AdditiveArithmetic?

I actually didn't implement that logic, will do so now.
Done in 008789b.

Fix validation-test/ParseableInterface/verify_all_overlays.py.
Address review comments.
@dan-zheng
Copy link
Contributor Author

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit d214a91 into swiftlang:tensorflow Jul 1, 2019
@dan-zheng dan-zheng deleted the class-derived-differentiable branch July 1, 2019 22:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants