Skip to content

[Sema] Differentiable conformance derivation for class types. #25914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2869,9 +2869,9 @@ ERROR(compiler_evaluable_ref_non_compiler_evaluable,none,
"@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ())

// @noDerivative attribute
ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
"'@noDerivative' is only allowed on stored properties in structure types "
"that declare a conformance to 'Differentiable'", ())
ERROR(noderivative_only_on_differentiable_struct_or_class_fields,none,
"'@noDerivative' is only allowed on stored properties in structure or "
"class types that declare a conformance to 'Differentiable'", ())

//------------------------------------------------------------------------------
// MARK: Type Check Expressions
Expand Down
29 changes: 13 additions & 16 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// SWIFT_ENABLE_TENSORFLOW
//
// This file implements explicit derivation of the Differentiable protocol for
// struct types.
// struct and class types.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -108,8 +108,7 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) {
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
assert(diffableProto && "`Differentiable` protocol not found");
auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(),
diffableProto,
DC, None);
diffableProto, DC, None);
assert(conf && "Nominal must conform to `Differentiable`");
Type assocType = conf->getTypeWitnessByName(DC->getSelfTypeInContext(), id);
assert(assocType && "`Differentiable` protocol associated type not found");
Expand All @@ -120,9 +119,8 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) {

bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
DeclContext *DC) {
// Nominal type must be a struct. (Zero stored properties is okay.)
auto *structDecl = dyn_cast<StructDecl>(nominal);
if (!structDecl)
// Nominal type must be a struct or class. (No stored properties is okay.)
if (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))
return false;
auto &C = nominal->getASTContext();
auto *lazyResolver = C.getLazyResolver();
Expand Down Expand Up @@ -153,8 +151,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
// `X == X.TangentVector`.
if (nominal->isImplicit() && structDecl == nominal->getDeclContext() &&
TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(),
diffableProto, DC,
None))
diffableProto, DC, None))
return structDecl;
// 3. Equal nominal (and conform to `AdditiveArithmetic` if flag is true).
if (structDecl == nominal) {
Expand Down Expand Up @@ -199,7 +196,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
// initializers that initialize all stored properties, including initial
// value information.
SmallVector<VarDecl *, 16> diffProperties;
getStoredPropertiesForDifferentiation(structDecl, DC, diffProperties);
getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
return llvm::all_of(diffProperties, [&](VarDecl *v) {
if (!v->hasInterfaceType())
lazyResolver->resolveDeclSignature(v);
Expand Down Expand Up @@ -325,7 +322,8 @@ static ValueDecl *deriveDifferentiable_method(
/*Throws*/ false, SourceLoc(),
/*GenericParams=*/nullptr, params,
TypeLoc::withoutLoc(returnType), parentDC);
funcDecl->setSelfAccessKind(SelfAccessKind::Mutating);
if (!nominal->getSelfClassDecl())
funcDecl->setSelfAccessKind(SelfAccessKind::Mutating);
funcDecl->setImplicit();
funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);

Expand Down Expand Up @@ -804,6 +802,8 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
DeclContext* DC) {
auto *diffableProto =
TC.Context.getProtocol(KnownProtocolKind::Differentiable);
bool nominalCanDeriveAdditiveArithmetic =
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
for (auto *vd : nominal->getStoredProperties()) {
if (!vd->hasInterfaceType())
TC.resolveDeclSignature(vd);
Expand All @@ -814,8 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
continue;
// Check whether to diagnose stored property.
bool conformsToDifferentiable =
TC.conformsToProtocol(varType, diffableProto, nominal,
None).hasValue();
TC.conformsToProtocol(varType, diffableProto, nominal, None).hasValue();
// If stored property should not be diagnosed, continue.
if (conformsToDifferentiable && !vd->isLet())
continue;
Expand All @@ -829,8 +828,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
// `Differentiable` protocol requirements all have default implementations
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
// derived conformances will no longer be necessary.
bool nominalCanDeriveAdditiveArithmetic =
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
if (!conformsToDifferentiable) {
TC.diagnose(loc,
diag::differentiable_nondiff_type_implicit_noderivative_fixit,
Expand All @@ -844,7 +841,6 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
vd->getName(), nominal->getName(),
nominalCanDeriveAdditiveArithmetic)
.fixItInsert(loc, "@noDerivative ");

}
}

Expand Down Expand Up @@ -954,6 +950,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties;

// Check conditions for returning `Self`.
// - `Self` is not a class type.
// - No `@noDerivative` stored properties exist.
// - All stored properties must have specified associated type equal to
// `Self`.
Expand All @@ -971,7 +968,7 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
parentDC, None);

// Return `Self` if conditions are met.
if (!hasNoDerivativeStoredProp &&
if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() &&
(id == C.Id_AllDifferentiableVariables ||
(allMembersAssocTypeEqualsSelf && nominalConformsToAddArith))) {
auto selfType = parentDC->getSelfTypeInContext();
Expand Down
6 changes: 4 additions & 2 deletions lib/Sema/DerivedConformanceRingMathProtocols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
// If conformance reference is concrete, then use concrete witness
// declaration for the operator.
if (confRef->isConcrete())
memberOpDecl = confRef->getConcrete()->getWitnessDecl(
operatorReq, C.getLazyResolver());
if (auto *concreteMemberMethodDecl =
confRef->getConcrete()->getWitnessDecl(operatorReq,
C.getLazyResolver()))
memberOpDecl = concreteMemberMethodDecl;
assert(memberOpDecl && "Member operator declaration must exist");
auto memberOpDRE =
new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true);
Expand Down
8 changes: 7 additions & 1 deletion lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ DerivedConformance::declareDerivedPropertySetter(TypeChecker &tc,
/*GenericParams*/ nullptr, params, TypeLoc(), parentDC);
setterDecl->setImplicit();
setterDecl->setStatic(isStatic);
setterDecl->setSelfAccessKind(SelfAccessKind::Mutating);
// Set mutating if parent is not a class.
if (!parentDC->getSelfClassDecl())
setterDecl->setSelfAccessKind(SelfAccessKind::Mutating);

// If this is supposed to be a final method, mark it as such.
assert(isFinal || !parentDC->getSelfClassDecl());
Expand Down Expand Up @@ -584,6 +586,10 @@ DerivedConformance::declareDerivedProperty(Identifier name,
VarDecl *propDecl = new (C) VarDecl(/*IsStatic*/isStatic, VarDecl::Specifier::Var,
/*IsCaptureList*/false, SourceLoc(), name,
parentDC);
// SWIFT_ENABLE_TENSORFLOW
// TODO: Upstream this change to master.
if (isFinal && parentDC->getSelfClassDecl())
propDecl->getAttrs().add(new (C) FinalAttr(/*Implicit*/ true));
propDecl->setImplicit();
propDecl->copyFormalAccessFrom(Nominal, /*sourceIsParentContext*/ true);
propDecl->setInterfaceType(propertyInterfaceType);
Expand Down
14 changes: 7 additions & 7 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3759,20 +3759,20 @@ void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
return;
if (!vd || vd->isStatic()) {
diagnoseAndRemoveAttr(attr,
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
auto *structDecl = dyn_cast<StructDecl>(vd->getDeclContext());
if (!structDecl) {
auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl();
if (!nominal || (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))) {
diagnoseAndRemoveAttr(attr,
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
if (!conformsToDifferentiable(
structDecl->getDeclaredInterfaceType(),
structDecl->getDeclContext())) {
nominal->getDeclaredInterfaceType(),
nominal->getDeclContext())) {
diagnoseAndRemoveAttr(attr,
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public struct Foo : Differentiable {
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables

// CHECK-SILGEN-LABEL: // Foo.a.getter
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float

struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
var a: Float
Expand Down Expand Up @@ -82,9 +82,6 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
var x: T.TangentVector
}

// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized.
// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized.

// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST: internal var x: T.TangentVector
// CHECK-AST: internal init(x: T.TangentVector)
Expand All @@ -105,3 +102,32 @@ extension ConditionallyDifferentiable : Differentiable where T : Differentiable
// CHECK-AST: public var x: T
// CHECK-AST: internal init(x: T)
// CHECK-AST: }

// Verify that `TangentVector` is not synthesized to be `Self` for
// `AdditiveArithmetic`-conforming classes.
final class AdditiveArithmeticClass<T : AdditiveArithmetic & Differentiable> : AdditiveArithmetic, Differentiable {
var x, y: T
init(x: T, y: T) {
self.x = x
self.y = y
}

// Dummy `AdditiveArithmetic` requirements.
static func == (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Bool {
fatalError()
}
static var zero: AdditiveArithmeticClass {
fatalError()
}
static func + (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
fatalError()
}
static func - (lhs: AdditiveArithmeticClass, rhs: AdditiveArithmeticClass) -> Self {
fatalError()
}
}

// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
// CHECK-AST: final internal var x: T, y: T
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
// CHECK-AST: }
4 changes: 2 additions & 2 deletions test/AutoDiff/noderivative-attr.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: %target-swift-frontend -typecheck -verify %s

// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
@noDerivative var flag: Bool

struct Foo {
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
// expected-error @+1 {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
@noDerivative var flag: Bool
}

Expand Down
9 changes: 9 additions & 0 deletions test/Sema/Inputs/class_differentiable_other_module.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SWIFT_ENABLE_TENSORFLOW

// expected-note @+1 {{type declared here}}
class OtherFileNonconforming {}

// expected-note @+1 {{type declared here}}
class GenericOtherFileNonconforming<T : Differentiable> {
var x: T
}
Loading