Skip to content

Derive ElementaryFunctions conformances for structs. #25500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2532,6 +2532,8 @@ ERROR(parameterized_invalid_parameters_struct,none,
"invalid", (Type))
ERROR(broken_additive_arithmetic_requirement,none,
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
ERROR(broken_elementary_functions_requirement,none,
"ElementaryFunctions protocol is broken: unexpected requirement", ())
ERROR(broken_vector_protocol_requirement,none,
"VectorProtocol protocol is broken: unexpected requirement", ())
ERROR(broken_differentiable_requirement,none,
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ PROTOCOL(Encodable)
PROTOCOL(Decodable)
// SWIFT_ENABLE_TENSORFLOW
PROTOCOL(AdditiveArithmetic)
PROTOCOL(ElementaryFunctions)
PROTOCOL(KeyPathIterable)
PROTOCOL(TensorArrayProtocol)
PROTOCOL(TensorGroup)
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4206,6 +4206,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::StringInterpolationProtocol:
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::ElementaryFunctions:
case KnownProtocolKind::KeyPathIterable:
case KnownProtocolKind::TensorArrayProtocol:
case KnownProtocolKind::TensorGroup:
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_swift_host_library(swiftSema STATIC
DerivedConformanceError.cpp
# SWIFT_ENABLE_TENSORFLOW
DerivedConformanceAdditiveArithmetic.cpp
DerivedConformanceElementaryFunctions.cpp
DerivedConformanceVectorProtocol.cpp
DerivedConformanceDifferentiable.cpp
DerivedConformanceKeyPathIterable.cpp
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/DerivedConformanceAdditiveArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
memberOpExprs.push_back(createMemberOpExpr(member));
memberNames.push_back(member->getName());
}
// Call memberwise initialier with member operator call expressions.
// Call memberwise initializer with member operator call expressions.
auto *callExpr =
CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames);
ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
Expand Down
21 changes: 17 additions & 4 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,14 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType());
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType());
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType());
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());

SmallVector<TypeLoc, 3> inherited {diffableType};
SmallVector<TypeLoc, 3> inherited{diffableType};

// Cache original members and their associated types for later use.
SmallVector<VarDecl *, 8> diffProperties;
Expand All @@ -589,9 +591,16 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
bool canDeriveAdditiveArithmetic =
llvm::all_of(diffProperties, [&](VarDecl *vd) {
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
addArithProto, parentDC,
None);
});
addArithProto, parentDC, None);
});

// Associated struct can derive `ElementaryFunctions` if the associated types
// of all stored properties conform to `ElementaryFunctions`.
bool canDeriveElementaryFunctions =
llvm::all_of(diffProperties, [&](VarDecl *vd) {
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
mathProto, parentDC, None);
});

// Associated struct can derive `VectorProtocol` if the associated types of
// all members conform to `VectorProtocol` and share the same
Expand Down Expand Up @@ -625,6 +634,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
None))
inherited.push_back(kpIterableType);
}
// If all members conform to `ElementaryFunctions`, make the associated struct
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic (conform synthesized Differentiable associated types to ElementaryFunctions if possible) is not super principled. Similar logic exists for conforming synthesized types to VectorProtocol if possible.

We could remove this logic by constraining TangentVector to VectorProtocol and/or ElementaryFunctions, but it's not clear to me whether that's desirable. Constraining TangentVector to VectorProtocol seems more reasonable.

// conform to `ElementaryFunctions`.
if (canDeriveElementaryFunctions)
inherited.push_back(mathType);
// If all members also conform to `VectorProtocol` with the same `Scalar`
// type, make the associated struct conform to `VectorProtocol` instead of
// just `AdditiveArithmetic`.
Expand Down
Loading