diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index d9a4237da07f7..6cf008c82e72c 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2869,9 +2869,9 @@ ERROR(compiler_evaluable_ref_non_compiler_evaluable,none, "@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ()) // @noDerivative attribute -ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none, - "'@noDerivative' is only allowed on stored properties in structure types " - "that declare a conformance to 'Differentiable'", ()) +ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none, + "'@noDerivative' is only allowed on stored properties in structure or " + "class types that declare a conformance to 'Differentiable'", ()) //------------------------------------------------------------------------------ // MARK: Type Check Expressions diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 75ba87eda662d..ef366c432a494 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -13,7 +13,7 @@ // SWIFT_ENABLE_TENSORFLOW // // This file implements explicit derivation of the Differentiable protocol for -// struct types. +// struct and class types. // //===----------------------------------------------------------------------===// @@ -108,8 +108,7 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) { auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); assert(diffableProto && "`Differentiable` protocol not found"); auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(), - diffableProto, - DC, None); + diffableProto, DC, None); assert(conf && "Nominal must conform to `Differentiable`"); Type assocType = conf->getTypeWitnessByName(DC->getSelfTypeInContext(), id); assert(assocType && "`Differentiable` protocol associated type not found"); @@ -120,9 +119,8 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) { bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, DeclContext *DC) { - // Nominal type must be a struct. (Zero stored properties is okay.) - auto *structDecl = dyn_cast(nominal); - if (!structDecl) + // Nominal type must be a struct or class. (No stored properties is okay.) + if (!isa(nominal) && !isa(nominal)) return false; auto &C = nominal->getASTContext(); auto *lazyResolver = C.getLazyResolver(); @@ -153,8 +151,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, // `X == X.TangentVector`. if (nominal->isImplicit() && structDecl == nominal->getDeclContext() && TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(), - diffableProto, DC, - None)) + diffableProto, DC, None)) return structDecl; // 3. Equal nominal (and conform to `AdditiveArithmetic` if flag is true). if (structDecl == nominal) { @@ -199,7 +196,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, // initializers that initialize all stored properties, including initial // value information. SmallVector diffProperties; - getStoredPropertiesForDifferentiation(structDecl, DC, diffProperties); + getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); return llvm::all_of(diffProperties, [&](VarDecl *v) { if (!v->hasInterfaceType()) lazyResolver->resolveDeclSignature(v); @@ -325,7 +322,8 @@ static ValueDecl *deriveDifferentiable_method( /*Throws*/ false, SourceLoc(), /*GenericParams=*/nullptr, params, TypeLoc::withoutLoc(returnType), parentDC); - funcDecl->setSelfAccessKind(SelfAccessKind::Mutating); + if (!nominal->getSelfClassDecl()) + funcDecl->setSelfAccessKind(SelfAccessKind::Mutating); funcDecl->setImplicit(); funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context); @@ -804,6 +802,8 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, DeclContext* DC) { auto *diffableProto = TC.Context.getProtocol(KnownProtocolKind::Differentiable); + bool nominalCanDeriveAdditiveArithmetic = + DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC); for (auto *vd : nominal->getStoredProperties()) { if (!vd->hasInterfaceType()) TC.resolveDeclSignature(vd); @@ -814,8 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, continue; // Check whether to diagnose stored property. bool conformsToDifferentiable = - TC.conformsToProtocol(varType, diffableProto, nominal, - None).hasValue(); + TC.conformsToProtocol(varType, diffableProto, nominal, None).hasValue(); // If stored property should not be diagnosed, continue. if (conformsToDifferentiable && !vd->isLet()) continue; @@ -829,8 +828,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, // `Differentiable` protocol requirements all have default implementations // when `Self` conforms to `AdditiveArithmetic`, so `Differentiable` // derived conformances will no longer be necessary. - bool nominalCanDeriveAdditiveArithmetic = - DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC); if (!conformsToDifferentiable) { TC.diagnose(loc, diag::differentiable_nondiff_type_implicit_noderivative_fixit, @@ -844,7 +841,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, vd->getName(), nominal->getName(), nominalCanDeriveAdditiveArithmetic) .fixItInsert(loc, "@noDerivative "); - } } @@ -954,6 +950,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties; // Check conditions for returning `Self`. + // - `Self` is not a class type. // - No `@noDerivative` stored properties exist. // - All stored properties must have specified associated type equal to // `Self`. @@ -971,7 +968,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, parentDC, None); // Return `Self` if conditions are met. - if (!hasNoDerivativeStoredProp && + if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() && (id == C.Id_AllDifferentiableVariables || (allMembersAssocTypeEqualsSelf && nominalConformsToAddArith))) { auto selfType = parentDC->getSelfTypeInContext(); diff --git a/lib/Sema/DerivedConformanceRingMathProtocols.cpp b/lib/Sema/DerivedConformanceRingMathProtocols.cpp index ea32876872e1d..30d288377f308 100644 --- a/lib/Sema/DerivedConformanceRingMathProtocols.cpp +++ b/lib/Sema/DerivedConformanceRingMathProtocols.cpp @@ -178,8 +178,10 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl, // If conformance reference is concrete, then use concrete witness // declaration for the operator. if (confRef->isConcrete()) - memberOpDecl = confRef->getConcrete()->getWitnessDecl( - operatorReq, C.getLazyResolver()); + if (auto *concreteMemberMethodDecl = + confRef->getConcrete()->getWitnessDecl(operatorReq, + C.getLazyResolver())) + memberOpDecl = concreteMemberMethodDecl; assert(memberOpDecl && "Member operator declaration must exist"); auto memberOpDRE = new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true); diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index c086b6ead2a92..2717c4a837558 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -554,7 +554,9 @@ DerivedConformance::declareDerivedPropertySetter(TypeChecker &tc, /*GenericParams*/ nullptr, params, TypeLoc(), parentDC); setterDecl->setImplicit(); setterDecl->setStatic(isStatic); - setterDecl->setSelfAccessKind(SelfAccessKind::Mutating); + // Set mutating if parent is not a class. + if (!parentDC->getSelfClassDecl()) + setterDecl->setSelfAccessKind(SelfAccessKind::Mutating); // If this is supposed to be a final method, mark it as such. assert(isFinal || !parentDC->getSelfClassDecl()); @@ -584,6 +586,10 @@ DerivedConformance::declareDerivedProperty(Identifier name, VarDecl *propDecl = new (C) VarDecl(/*IsStatic*/isStatic, VarDecl::Specifier::Var, /*IsCaptureList*/false, SourceLoc(), name, parentDC); + // SWIFT_ENABLE_TENSORFLOW + // TODO: Upstream this change to master. + if (isFinal && parentDC->getSelfClassDecl()) + propDecl->getAttrs().add(new (C) FinalAttr(/*Implicit*/ true)); propDecl->setImplicit(); propDecl->copyFormalAccessFrom(Nominal, /*sourceIsParentContext*/ true); propDecl->setInterfaceType(propertyInterfaceType); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index a52d96c5cb929..ef126eaa1d76d 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3759,20 +3759,20 @@ void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) { return; if (!vd || vd->isStatic()) { diagnoseAndRemoveAttr(attr, - diag::noderivative_only_on_stored_properties_in_differentiable_structs); + diag::noderivative_only_on_differentiable_struct_or_class_fields); return; } - auto *structDecl = dyn_cast(vd->getDeclContext()); - if (!structDecl) { + auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl(); + if (!nominal || (!isa(nominal) && !isa(nominal))) { diagnoseAndRemoveAttr(attr, - diag::noderivative_only_on_stored_properties_in_differentiable_structs); + diag::noderivative_only_on_differentiable_struct_or_class_fields); return; } if (!conformsToDifferentiable( - structDecl->getDeclaredInterfaceType(), - structDecl->getDeclContext())) { + nominal->getDeclaredInterfaceType(), + nominal->getDeclContext())) { diagnoseAndRemoveAttr(attr, - diag::noderivative_only_on_stored_properties_in_differentiable_structs); + diag::noderivative_only_on_differentiable_struct_or_class_fields); return; } } diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable.swift similarity index 83% rename from test/AutoDiff/derived_differentiable_properties.swift rename to test/AutoDiff/derived_differentiable.swift index c2a06a5a58564..11d5e12583725 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable.swift @@ -18,7 +18,7 @@ public struct Foo : Differentiable { // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables // CHECK-SILGEN-LABEL: // Foo.a.getter -// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float +// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { var a: Float @@ -82,9 +82,6 @@ struct GenericTanMember : Differentiable, AdditiveArithmetic var x: T.TangentVector } -// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized. -// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized. - // CHECK-AST-LABEL: internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable // CHECK-AST: internal var x: T.TangentVector // CHECK-AST: internal init(x: T.TangentVector) @@ -105,3 +102,32 @@ extension ConditionallyDifferentiable : Differentiable where T : Differentiable // CHECK-AST: public var x: T // CHECK-AST: internal init(x: T) // CHECK-AST: } + +// Verify that `TangentVector` is not synthesized to be `Self` for +// `AdditiveArithmetic`-conforming classes. +final class AdditiveArithmeticClass : AdditiveArithmetic, Differentiable { + var x, y: T + init(x: T, y: T) { + self.x = x + self.y = y + } + + // Dummy `AdditiveArithmetic` requirements. + static func == (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Bool { + fatalError() + } + static var zero: AdditiveArithmeticClass { + fatalError() + } + static func + (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self { + fatalError() + } + static func - (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self { + fatalError() + } +} + +// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable { +// CHECK-AST: final internal var x: T, y: T +// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic +// CHECK-AST: } diff --git a/test/AutoDiff/noderivative-attr.swift b/test/AutoDiff/noderivative-attr.swift index df6db50e2d733..951a2d4aaaa43 100644 --- a/test/AutoDiff/noderivative-attr.swift +++ b/test/AutoDiff/noderivative-attr.swift @@ -1,10 +1,10 @@ // RUN: %target-swift-frontend -typecheck -verify %s -// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}} +// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}} @noDerivative var flag: Bool struct Foo { - // expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}} + // expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}} @noDerivative var flag: Bool } diff --git a/test/Sema/Inputs/class_differentiable_other_module.swift b/test/Sema/Inputs/class_differentiable_other_module.swift new file mode 100644 index 0000000000000..1850e4673ac0d --- /dev/null +++ b/test/Sema/Inputs/class_differentiable_other_module.swift @@ -0,0 +1,9 @@ +// SWIFT_ENABLE_TENSORFLOW + +// expected-note @+1 {{type declared here}} +class OtherFileNonconforming {} + +// expected-note @+1 {{type declared here}} +class GenericOtherFileNonconforming { + var x: T +} diff --git a/test/Sema/class_differentiable.swift b/test/Sema/class_differentiable.swift new file mode 100644 index 0000000000000..d2aad2024fe75 --- /dev/null +++ b/test/Sema/class_differentiable.swift @@ -0,0 +1,570 @@ +// SWIFT_ENABLE_TENSORFLOW +// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/class_differentiable_other_module.swift + +// Verify that a `Differentiable` type upholds `AllDifferentiableVariables == TangentVector`. +func assertAllDifferentiableVariablesEqualsTangentVector(_: T.Type) + where T : Differentiable, T.AllDifferentiableVariables == T.TangentVector {} + +// Verify that a type `T` conforms to `AdditiveArithmetic`. +func assertConformsToAdditiveArithmetic(_: T.Type) where T : AdditiveArithmetic {} + +// Verify that a type `T` conforms to `ElementaryFunctions`. +func assertConformsToElementaryFunctions(_: T.Type) where T : ElementaryFunctions {} + +// Verify that a type `T` conforms to `VectorProtocol`. +func assertConformsToVectorProtocol(_: T.Type) where T : VectorProtocol {} + +// Dummy protocol with default implementations for `AdditiveArithmetic` requirements. +// Used to test `Self : AdditiveArithmetic` requirements. +protocol DummyAdditiveArithmetic : AdditiveArithmetic {} +extension DummyAdditiveArithmetic { + static func == (lhs: Self, rhs: Self) -> Bool { + fatalError() + } + static var zero: Self { + fatalError() + } + static func + (lhs: Self, rhs: Self) -> Self { + fatalError() + } + static func - (lhs: Self, rhs: Self) -> Self { + fatalError() + } +} + +class Empty : Differentiable {} +func testEmpty() { + assertConformsToAdditiveArithmetic(Empty.AllDifferentiableVariables.self) + assertConformsToAdditiveArithmetic(Empty.TangentVector.self) + assertConformsToElementaryFunctions(Empty.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(Empty.TangentVector.self) +} + +// Test structs with `let` stored properties. +// Derived conformances fail because `mutating func move` requires all stored +// properties to be mutable. +class ImmutableStoredProperties : Differentiable { + var okay: Float + + // expected-warning @+1 {{stored property 'nondiff' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + let nondiff: Int + + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + let diff: Float + + init() { + okay = 0 + nondiff = 0 + diff = 0 + } +} +func testImmutableStoredProperties() { + _ = ImmutableStoredProperties.TangentVector(okay: 1) +} +class MutableStoredPropertiesWithInitialValue : Differentiable { + var x = Float(1) + var y = Double(1) +} +// Test class with both an empty constructor and memberwise initializer. +class AllMixedStoredPropertiesHaveInitialValue : Differentiable { + let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + var y = Float(1) + // Memberwise initializer should be `init(y:)` since `x` is immutable. + static func testMemberwiseInitializer() { + _ = AllMixedStoredPropertiesHaveInitialValue() + } +} +/* +class HasCustomConstructor: Differentiable { + var x = Float(1) + var y = Float(1) + // Custom constructor should not affect synthesis. + init(x: Float, y: Float, z: Bool) {} +} +*/ + +class Simple : Differentiable { + var w: Float + var b: Float + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +func testSimple() { + let simple = Simple(w: 1, b: 1) + let tangent = Simple.TangentVector(w: 1, b: 1) + simple.move(along: tangent) +} + +// Test type with mixed members. +class Mixed : Differentiable { + var simple: Simple + var float: Float + + init(simple: Simple, float: Float) { + self.simple = simple + self.float = float + } +} +func testMixed(_ simple: Simple, _ simpleTangent: Simple.TangentVector) { + let mixed = Mixed(simple: simple, float: 1) + let tangent = Mixed.TangentVector(simple: simpleTangent, float: 1) + mixed.move(along: tangent) +} + +// Test type with manual definition of vector space types to `Self`. +final class VectorSpacesEqualSelf : Differentiable & DummyAdditiveArithmetic { + var w: Float + var b: Float + typealias TangentVector = VectorSpacesEqualSelf + typealias AllDifferentiableVariables = VectorSpacesEqualSelf + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +/* +extension VectorSpacesEqualSelf : Equatable, AdditiveArithmetic { + static func == (lhs: VectorSpacesEqualSelf, rhs: VectorSpacesEqualSelf) -> Bool { + fatalError() + } + static var zero: VectorSpacesEqualSelf { + fatalError() + } + static func + (lhs: VectorSpacesEqualSelf, rhs: VectorSpacesEqualSelf) -> VectorSpacesEqualSelf { + fatalError() + } + static func - (lhs: VectorSpacesEqualSelf, rhs: VectorSpacesEqualSelf) -> VectorSpacesEqualSelf { + fatalError() + } +} +*/ + +// Test generic type with vector space types to `Self`. +class GenericVectorSpacesEqualSelf : Differentiable + where T : Differentiable, T == T.TangentVector, + T == T.AllDifferentiableVariables +{ + var w: T + var b: T + + init(w: T, b: T) { + self.w = w + self.b = b + } +} +func testGenericVectorSpacesEqualSelf() { + let genericSame = GenericVectorSpacesEqualSelf(w: 1, b: 1) + let tangent = GenericVectorSpacesEqualSelf.TangentVector(w: 1, b: 1) + genericSame.move(along: tangent) +} + +// Test nested type. +class Nested : Differentiable { + var simple: Simple + var mixed: Mixed + var generic: GenericVectorSpacesEqualSelf + + init(simple: Simple, mixed: Mixed, generic: GenericVectorSpacesEqualSelf) { + self.simple = simple + self.mixed = mixed + self.generic = generic + } +} +func testNested( + _ simple: Simple, _ mixed: Mixed, + _ genericSame: GenericVectorSpacesEqualSelf +) { + _ = Nested(simple: simple, mixed: mixed, generic: genericSame) +} + +// Test type that does not conform to `AdditiveArithmetic` but whose members do. +// Thus, `Self` cannot be used as `TangentVector` or `TangentVector`. +// Vector space structs types must be synthesized. +// Note: it would be nice to emit a warning if conforming `Self` to +// `AdditiveArithmetic` is possible. +class AllMembersAdditiveArithmetic : Differentiable { + var w: Float + var b: Float + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +func testAllMembersAdditiveArithmetic() { + assertAllDifferentiableVariablesEqualsTangentVector(AllMembersAdditiveArithmetic.self) +} + +// Test type `AllMembersVectorProtocol` whose members conforms to `VectorProtocol`, +// in which case we should make `AllDifferentiableVariables` and `TangentVector` +// conform to `VectorProtocol`. +struct MyVector : VectorProtocol, Differentiable { + var w: Float + var b: Float + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +class AllMembersVectorProtocol : Differentiable { + var v1: MyVector + var v2: MyVector + + init(v: MyVector) { + v1 = v + v2 = v + } +} +func testAllMembersVectorProtocol() { + assertConformsToVectorProtocol(AllMembersVectorProtocol.AllDifferentiableVariables.self) + assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self) +} + +// Test type `AllMembersElementaryFunctions` whose members conforms to `ElementaryFunctions`, +// in which case we should make `AllDifferentiableVariables` and `TangentVector` +// conform to `ElementaryFunctions`. +struct MyVector2 : ElementaryFunctions, Differentiable { + var w: Float + var b: Float + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +class AllMembersElementaryFunctions : Differentiable { + var v1: MyVector2 + var v2: MyVector2 + + init(v: MyVector2) { + v1 = v + v2 = v + } +} +func testAllMembersElementaryFunctions() { + assertConformsToElementaryFunctions(AllMembersElementaryFunctions.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(AllMembersElementaryFunctions.TangentVector.self) +} + +// Test type whose properties are not all differentiable. +class DifferentiableSubset : Differentiable { + var w: Float + var b: Float + @noDerivative var flag: Bool + @noDerivative let technicallyDifferentiable: Float = .pi + + init(w: Float, b: Float, flag: Bool) { + self.w = w + self.b = b + self.flag = flag + } +} +func testDifferentiableSubset() { + assertConformsToAdditiveArithmetic(DifferentiableSubset.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(DifferentiableSubset.AllDifferentiableVariables.self) + assertConformsToVectorProtocol(DifferentiableSubset.AllDifferentiableVariables.self) + assertAllDifferentiableVariablesEqualsTangentVector(DifferentiableSubset.self) + _ = DifferentiableSubset.TangentVector(w: 1, b: 1) + _ = DifferentiableSubset.TangentVector(w: 1, b: 1) + _ = DifferentiableSubset.AllDifferentiableVariables(w: 1, b: 1) + + _ = pullback(at: DifferentiableSubset(w: 1, b: 2, flag: false)) { model in + model.w + model.b + } +} + +// Test nested type whose properties are not all differentiable. +class NestedDifferentiableSubset : Differentiable { + var x: DifferentiableSubset + var mixed: Mixed + @noDerivative var technicallyDifferentiable: Float + + init(x: DifferentiableSubset, mixed: Mixed) { + self.x = x + self.mixed = mixed + technicallyDifferentiable = 0 + } +} +func testNestedDifferentiableSubset() { + assertAllDifferentiableVariablesEqualsTangentVector(NestedDifferentiableSubset.self) +} + +// Test type that uses synthesized vector space types but provides custom +// method. +class HasCustomMethod : Differentiable { + var simple: Simple + var mixed: Mixed + var generic: GenericVectorSpacesEqualSelf + + init(simple: Simple, mixed: Mixed, generic: GenericVectorSpacesEqualSelf) { + self.simple = simple + self.mixed = mixed + self.generic = generic + } + + func move(along direction: TangentVector) { + print("Hello world") + simple.move(along: direction.simple) + mixed.move(along: direction.mixed) + generic.move(along: direction.generic) + } +} + +// Test type that conforms to `KeyPathIterable`. +// The `AllDifferentiableVariables` class should also conform to `KeyPathIterable`. +class TestKeyPathIterable : Differentiable, KeyPathIterable { + var w: Float + @noDerivative let technicallyDifferentiable: Float = .pi + + // NOTE: `KeyPathIterable` derived conformances do not yet support class + // types. + var allKeyPaths: [PartialKeyPath] { + [\TestKeyPathIterable.w, \TestKeyPathIterable.technicallyDifferentiable] + } + + init(w: Float) { + self.w = w + technicallyDifferentiable = 0 + } +} +func testKeyPathIterable(x: TestKeyPathIterable) { + _ = x.allDifferentiableVariables.allKeyPaths +} + +// Test type with user-defined memberwise initializer. +class TF_25: Differentiable { + public var bar: Float + public init(bar: Float) { + self.bar = bar + } +} +// Test user-defined memberwise initializer. +class TF_25_Generic: Differentiable { + public var bar: T + public init(bar: T) { + self.bar = bar + } +} + +// Test initializer that is not a memberwise initializer because of stored property name vs parameter label mismatch. +class HasCustomNonMemberwiseInitializer: Differentiable { + var value: T + init(randomLabel value: T) { self.value = value } +} + +// Test type with generic environment. +class HasGenericEnvironment : Differentiable { + var x: Float = 0 +} + +// Test type with generic members that conform to `Differentiable`. +class GenericSynthesizeAllStructs : Differentiable { + var w: T + var b: T + + init(w: T, b: T) { + self.w = w + self.b = b + } +} + +// Test type in generic context. +class A { + class B : Differentiable { + class InGenericContext : Differentiable { + @noDerivative var a: A + var b: B + var t: T + var u: U + + init(a: A, b: B, t: T, u: U) { + self.a = a + self.b = b + self.t = t + self.u = u + } + } + } +} + +// Test extension. +class Extended { + var x: Float + + init(x: Float) { + self.x = x + } +} +extension Extended : Differentiable {} + +// Test extension of generic type. +class GenericExtended { + var x: T + + init(x: T) { + self.x = x + } +} +extension GenericExtended : Differentiable where T : Differentiable {} + +// Test constrained extension of generic type. +class GenericConstrained { + var x: T + + init(x: T) { + self.x = x + } +} +extension GenericConstrained : Differentiable + where T : Differentiable {} + +final class TF_260 : Differentiable & DummyAdditiveArithmetic { + var x: T.TangentVector + + init(x: T.TangentVector) { + self.x = x + } +} + +// TF-269: Test crash when differentiation properties have no getter. +// Related to access levels and associated type inference. + +// TODO(TF-631): Blocked by class type differentiation support. +// [AD] Unhandled instruction in adjoint emitter: %2 = ref_element_addr %0 : $TF_269, #TF_269.filter // user: %3 +// [AD] Diagnosing non-differentiability. +// [AD] For instruction: +// %2 = ref_element_addr %0 : $TF_269, #TF_269.filter // user: %3 +/* +public protocol TF_269_Layer: Differentiable & KeyPathIterable + where AllDifferentiableVariables: KeyPathIterable { + + associatedtype Input: Differentiable + associatedtype Output: Differentiable + func applied(to input: Input) -> Output +} + +public class TF_269 : TF_269_Layer { + public var filter: Float + public typealias Activation = @differentiable (Output) -> Output + @noDerivative public let activation: @differentiable (Output) -> Output + + init(filter: Float, activation: @escaping Activation) { + self.filter = filter + self.activation = activation + } + + // NOTE: `KeyPathIterable` derived conformances do not yet support class + // types. + public var allKeyPaths: [PartialKeyPath] { + [] + } + + public func applied(to input: Float) -> Float { + return input + } +} +*/ + +// Test errors. + +// expected-error @+1 {{class 'MissingInitializer' has no initializers}} +class MissingInitializer : Differentiable { + // expected-note @+1 {{stored property 'w' without initial value prevents synthesized initializers}} + var w: Float + // expected-note @+1 {{stored property 'b' without initial value prevents synthesized initializers}} + var b: Float +} + +// Test manually customizing vector space types. +// Thees should fail. Synthesis is semantically unsupported if vector space +// types are customized. +final class TangentVectorWB : DummyAdditiveArithmetic, Differentiable { + var w: Float + var b: Float + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +// expected-error @+1 {{type 'VectorSpaceTypeAlias' does not conform to protocol 'Differentiable'}} +final class VectorSpaceTypeAlias : DummyAdditiveArithmetic, Differentiable { + var w: Float + var b: Float + typealias TangentVector = TangentVectorWB + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} +// expected-error @+1 {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}} +final class VectorSpaceCustomStruct : DummyAdditiveArithmetic, Differentiable { + var w: Float + var b: Float + struct TangentVector : AdditiveArithmetic, Differentiable { + var w: Float.TangentVector + var b: Float.TangentVector + typealias TangentVector = VectorSpaceCustomStruct.TangentVector + } + + init(w: Float, b: Float) { + self.w = w + self.b = b + } +} + +class StaticNoDerivative : Differentiable { + @noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}} +} + +final class StaticMembersShouldNotAffectAnything : DummyAdditiveArithmetic, Differentiable { + static var x: Bool = true + static var y: Bool = false +} + +class ImplicitNoDerivative : Differentiable { + var a: Float = 0 + var b: Bool = true // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} +} + +class ImplicitNoDerivativeWithSeparateTangent : Differentiable { + var x: DifferentiableSubset + var b: Bool = true // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + + init(x: DifferentiableSubset) { + self.x = x + } +} + +// TF-265: Test invalid initializer (that uses a non-existent type). +class InvalidInitializer : Differentiable { + init(filterShape: (Int, Int, Int, Int), blah: NonExistentType) {} // expected-error {{use of undeclared type 'NonExistentType'}} +} + +// Test memberwise initializer synthesis. +final class NoMemberwiseInitializerExtended { + var value: T + init(_ value: T) { + self.value = value + } +} +extension NoMemberwiseInitializerExtended: Equatable, AdditiveArithmetic, DummyAdditiveArithmetic + where T : AdditiveArithmetic {} +extension NoMemberwiseInitializerExtended: Differentiable + where T : Differentiable & AdditiveArithmetic {} + +// Test derived conformances in disallowed contexts. + +// expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} +extension OtherFileNonconforming : Differentiable {} + +// expected-error @+2 {{type 'GenericOtherFileNonconforming' does not conform to protocol 'Differentiable'}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} +extension GenericOtherFileNonconforming : Differentiable {} diff --git a/test/Sema/struct_differentiable.swift b/test/Sema/struct_differentiable.swift index 7986362380153..8d97d38dd9295 100644 --- a/test/Sema/struct_differentiable.swift +++ b/test/Sema/struct_differentiable.swift @@ -204,9 +204,11 @@ struct HasCustomMethod : Differentiable { var simple: Simple var mixed: Mixed var generic: GenericVectorSpacesEqualSelf - func moved(along: TangentVector) -> HasCustomMethod { + mutating func move(along direction: TangentVector) { print("Hello world") - return self + simple.move(along: direction.simple) + mixed.move(along: direction.mixed) + generic.move(along: direction.generic) } } @@ -330,7 +332,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable { } struct StaticNoDerivative : Differentiable { - @noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}} + @noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}} } struct StaticMembersShouldNotAffectAnything : AdditiveArithmetic, Differentiable {