diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 5533e769c4950..e44a36293edf6 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2532,6 +2532,8 @@ ERROR(parameterized_invalid_parameters_struct,none, "invalid", (Type)) ERROR(broken_additive_arithmetic_requirement,none, "AdditiveArithmetic protocol is broken: unexpected requirement", ()) +ERROR(broken_elementary_functions_requirement,none, + "ElementaryFunctions protocol is broken: unexpected requirement", ()) ERROR(broken_vector_protocol_requirement,none, "VectorProtocol protocol is broken: unexpected requirement", ()) ERROR(broken_differentiable_requirement,none, diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 7d6fd183feb07..3275a7be6e332 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -78,6 +78,7 @@ PROTOCOL(Encodable) PROTOCOL(Decodable) // SWIFT_ENABLE_TENSORFLOW PROTOCOL(AdditiveArithmetic) +PROTOCOL(ElementaryFunctions) PROTOCOL(KeyPathIterable) PROTOCOL(TensorArrayProtocol) PROTOCOL(TensorGroup) diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 4194cc7da6e7e..960eda3d07d53 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -4206,6 +4206,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::StringInterpolationProtocol: // SWIFT_ENABLE_TENSORFLOW case KnownProtocolKind::AdditiveArithmetic: + case KnownProtocolKind::ElementaryFunctions: case KnownProtocolKind::KeyPathIterable: case KnownProtocolKind::TensorArrayProtocol: case KnownProtocolKind::TensorGroup: diff --git a/lib/Sema/CMakeLists.txt b/lib/Sema/CMakeLists.txt index 4c48747492e97..82f299f76517d 100644 --- a/lib/Sema/CMakeLists.txt +++ b/lib/Sema/CMakeLists.txt @@ -28,6 +28,7 @@ add_swift_host_library(swiftSema STATIC DerivedConformanceError.cpp # SWIFT_ENABLE_TENSORFLOW DerivedConformanceAdditiveArithmetic.cpp + DerivedConformanceElementaryFunctions.cpp DerivedConformanceVectorProtocol.cpp DerivedConformanceDifferentiable.cpp DerivedConformanceKeyPathIterable.cpp diff --git a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp index 94798dc9b8e08..94ea90e06a6e5 100644 --- a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp +++ b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp @@ -183,7 +183,7 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl, memberOpExprs.push_back(createMemberOpExpr(member)); memberNames.push_back(member->getName()); } - // Call memberwise initialier with member operator call expressions. + // Call memberwise initializer with member operator call expressions. auto *callExpr = CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames); ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index ce3093d952d52..6aa3dfde5ba8c 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -568,12 +568,14 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType()); auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType()); + auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); + auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType()); auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol); auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType()); auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable); auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType()); - SmallVector<TypeLoc, 3> inherited {diffableType}; + SmallVector<TypeLoc, 3> inherited{diffableType}; // Cache original members and their associated types for later use. SmallVector<VarDecl *, 8> diffProperties; @@ -589,9 +591,16 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, bool canDeriveAdditiveArithmetic = llvm::all_of(diffProperties, [&](VarDecl *vd) { return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id), - addArithProto, parentDC, - None); - }); + addArithProto, parentDC, None); + }); + + // Associated struct can derive `ElementaryFunctions` if the associated types + // of all stored properties conform to `ElementaryFunctions`. + bool canDeriveElementaryFunctions = + llvm::all_of(diffProperties, [&](VarDecl *vd) { + return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id), + mathProto, parentDC, None); + }); // Associated struct can derive `VectorProtocol` if the associated types of // all members conform to `VectorProtocol` and share the same @@ -625,6 +634,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, None)) inherited.push_back(kpIterableType); } + // If all members conform to `ElementaryFunctions`, make the associated struct + // conform to `ElementaryFunctions`. + if (canDeriveElementaryFunctions) + inherited.push_back(mathType); // If all members also conform to `VectorProtocol` with the same `Scalar` // type, make the associated struct conform to `VectorProtocol` instead of // just `AdditiveArithmetic`. diff --git a/lib/Sema/DerivedConformanceElementaryFunctions.cpp b/lib/Sema/DerivedConformanceElementaryFunctions.cpp new file mode 100644 index 0000000000000..8ed5a6df51ca8 --- /dev/null +++ b/lib/Sema/DerivedConformanceElementaryFunctions.cpp @@ -0,0 +1,338 @@ +//===--- DerivedConformanceElementaryFunctions.cpp ------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file implements explicit derivation of the ElementaryFunctions protocol +// for struct types. +// +//===----------------------------------------------------------------------===// + +#include "CodeSynthesis.h" +#include "TypeChecker.h" +#include "swift/AST/Decl.h" +#include "swift/AST/Expr.h" +#include "swift/AST/GenericSignature.h" +#include "swift/AST/Module.h" +#include "swift/AST/ParameterList.h" +#include "swift/AST/Pattern.h" +#include "swift/AST/ProtocolConformance.h" +#include "swift/AST/Stmt.h" +#include "swift/AST/Types.h" +#include "DerivedConformances.h" + +using namespace swift; + +// Represents synthesizable `ElementaryFunction` protocol requirements. +enum ElementaryFunction { +#define ELEMENTARY_FUNCTION(ID, NAME) ID, +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION +}; + +static StringRef getElementaryFunctionName(ElementaryFunction op) { + switch (op) { +#define ELEMENTARY_FUNCTION(ID, NAME) case ElementaryFunction::ID: return NAME; +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION + } +} + +// Return the protocol requirement with the specified name. +// TODO: Move function to shared place for use with other derived conformances. +static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { + auto lookup = proto->lookupDirect(name); + llvm::erase_if(lookup, [](ValueDecl *v) { + return !isa<ProtocolDecl>(v->getDeclContext()) || + !v->isProtocolRequirement(); + }); + assert(lookup.size() == 1 && "Ambiguous protocol requirement"); + return lookup.front(); +} + +// Return true if given nominal type has a `let` stored with an initial value. +// TODO: Move function to shared place for use with other derived conformances. +static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { + return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { + return v->isLet() && v->hasInitialValue(); + }); +} + +// Return the `ElementaryFunction` protocol requirement corresponding to the +// given elementary function. +static ValueDecl *getElementaryFunctionRequirement( + ASTContext &C, ElementaryFunction op) { + auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); + auto operatorId = C.getIdentifier(getElementaryFunctionName(op)); + switch (op) { +#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ + case ID: \ + return getProtocolRequirement(mathProto, operatorId); +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION_UNARY + case Root: + return getProtocolRequirement(mathProto, operatorId); + case Pow: + case PowInt: + auto lookup = mathProto->lookupDirect(operatorId); + lookup.erase(std::remove_if(lookup.begin(), lookup.end(), + [](ValueDecl *v) { + return !isa<ProtocolDecl>( + v->getDeclContext()) || + !v->isProtocolRequirement(); + }), + lookup.end()); + assert(lookup.size() == 2 && "Expected two 'pow' functions"); + auto *powFuncDecl = cast<FuncDecl>(lookup.front()); + auto secondParamType = + powFuncDecl->getParameters()->get(1)->getInterfaceType(); + if (secondParamType->getAnyNominal() == C.getIntDecl()) + return op == PowInt ? lookup.front() : lookup[1]; + else + return op == PowInt ? lookup[1] : lookup.front(); + } +} + +// Get the effective memberwise initializer of the given nominal type, or create +// it if it does not exist. +static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer( + TypeChecker &TC, NominalTypeDecl *nominal) { + auto &C = nominal->getASTContext(); + if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer()) + return initDecl; + auto *initDecl = createImplicitConstructor( + TC, nominal, ImplicitConstructorKind::Memberwise); + nominal->addMember(initDecl); + C.addSynthesizedDecl(initDecl); + return initDecl; +} + +bool DerivedConformance::canDeriveElementaryFunctions(NominalTypeDecl *nominal, + DeclContext *DC) { + // Nominal type must be a struct. (Zero stored properties is okay.) + auto *structDecl = dyn_cast<StructDecl>(nominal); + if (!structDecl) + return false; + // Must not have any `let` stored properties with an initial value. + // - This restriction may be lifted later with support for "true" memberwise + // initializers that initialize all stored properties, including initial + // value information. + if (hasLetStoredPropertyWithInitialValue(nominal)) + return false; + // All stored properties must conform to `ElementaryFunctions`. + auto &C = nominal->getASTContext(); + auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); + return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) { + if (!v->hasInterfaceType()) + C.getLazyResolver()->resolveDeclSignature(v); + if (!v->hasInterfaceType()) + return false; + auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); + return (bool)TypeChecker::conformsToProtocol(varType, mathProto, DC, None); + }); +} + +// Synthesize body for the given `ElementaryFunction` protocol requirement. +static void deriveBodyElementaryFunction(AbstractFunctionDecl *funcDecl, + ElementaryFunction op) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + // Create memberwise initializer: `Nominal.init(...)`. + auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer(); + assert(memberwiseInitDecl && "Memberwise initializer must exist"); + auto *initDRE = + new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal, + funcDecl, /*Implicit*/ true); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr); + + // Get operator protocol requirement. + auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); + auto *operatorReq = getElementaryFunctionRequirement(C, op); + + // Create reference(s) to operator parameters: one for unary functions and two + // for binary functions. + auto params = funcDecl->getParameters(); + auto *firstParamDRE = + new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true); + Expr *secondParamDRE = nullptr; + if (params->size() == 2) + secondParamDRE = + new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true); + + // Create call expression combining lhs and rhs members using member operator. + auto createMemberOpCallExpr = [&](VarDecl *member) -> Expr * { + auto module = nominal->getModuleContext(); + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto confRef = module->lookupConformance(memberType, mathProto); + assert(confRef && "Member does not conform to math protocol"); + + // Get member type's elementary function, e.g. `Member.cos`. + // Use protocol requirement declaration for the operator by default: this + // will be dynamically dispatched. + ValueDecl *memberOpDecl = operatorReq; + // If conformance reference is concrete, then use concrete witness + // declaration for the operator. + if (confRef->isConcrete()) + memberOpDecl = confRef->getConcrete()->getWitnessDecl( + operatorReq, C.getLazyResolver()); + assert(memberOpDecl && "Member operator declaration must exist"); + auto memberOpDRE = + new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto memberOpExpr = + new (C) DotSyntaxCallExpr(memberOpDRE, SourceLoc(), memberTypeExpr); + + // - For unary ops, create expression: + // `<op>(x.member)`. + // - For `pow(_ x: Self, _ y: Self)`, create expression: + // `<op>(x.member, y.member)`. + // - For `pow(_ x: Self, _ n: Int)` and `root(_ x: Self, n: Int)`, create: + // `<op>(x.member, n)`. + Expr *firstArg = new (C) MemberRefExpr(firstParamDRE, SourceLoc(), member, + DeclNameLoc(), /*Implicit*/ true); + Expr *secondArg = nullptr; + if (secondParamDRE) { + if (op == PowInt || op == Root) + secondArg = secondParamDRE; + else + secondArg = new (C) MemberRefExpr(secondParamDRE, SourceLoc(), member, + DeclNameLoc(), /*Implicit*/ true); + } + SmallVector<Expr *, 2> memberOpArgs{firstArg}; + if (secondArg) + memberOpArgs.push_back(secondArg); + SmallVector<Identifier, 2> memberOpArgLabels(memberOpArgs.size()); + auto *memberOpCallExpr = CallExpr::createImplicit( + C, memberOpExpr, memberOpArgs, memberOpArgLabels); + return memberOpCallExpr; + }; + + // Create array of member operator call expressions. + llvm::SmallVector<Expr *, 2> memberOpCallExprs; + llvm::SmallVector<Identifier, 2> memberNames; + for (auto member : nominal->getStoredProperties()) { + memberOpCallExprs.push_back(createMemberOpCallExpr(member)); + memberNames.push_back(member->getName()); + } + // Call memberwise initializer with member operator call expressions. + auto *callExpr = + CallExpr::createImplicit(C, initExpr, memberOpCallExprs, memberNames); + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); + funcDecl->setBody( + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true)); +} + +#define ELEMENTARY_FUNCTION(ID, NAME) \ +static void deriveBodyElementaryFunctions_##ID(AbstractFunctionDecl *funcDecl, \ + void *) { \ + deriveBodyElementaryFunction(funcDecl, ID); \ +} +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION + +// Synthesize function declaration for the given math operator. +static ValueDecl *deriveElementaryFunction(DerivedConformance &derived, +ElementaryFunction op) { + auto nominal = derived.Nominal; + auto parentDC = derived.getConformanceContext(); + auto &C = derived.TC.Context; + auto selfInterfaceType = parentDC->getDeclaredInterfaceType(); + + // Create parameter declaration with the given name and type. + auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * { + auto *param = new (C) + ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + Identifier(), SourceLoc(), C.getIdentifier(name), parentDC); + param->setInterfaceType(type); + return param; + }; + + ParameterList *params = nullptr; + + switch (op) { +#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ + case ID: \ + params = \ + ParameterList::create(C, {createParamDecl("x", selfInterfaceType)}); \ + break; +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION_UNARY + case Pow: + params = + ParameterList::create(C, {createParamDecl("x", selfInterfaceType), + createParamDecl("y", selfInterfaceType)}); + break; + case PowInt: + case Root: + params = ParameterList::create( + C, {createParamDecl("x", selfInterfaceType), + createParamDecl("n", C.getIntDecl()->getDeclaredInterfaceType())}); + break; + } + + auto operatorId = C.getIdentifier(getElementaryFunctionName(op)); + DeclName operatorDeclName(C, operatorId, params); + auto operatorDecl = + FuncDecl::create(C, SourceLoc(), StaticSpellingKind::KeywordStatic, + SourceLoc(), operatorDeclName, SourceLoc(), + /*Throws*/ false, SourceLoc(), + /*GenericParams*/ nullptr, params, + TypeLoc::withoutLoc(selfInterfaceType), parentDC); + operatorDecl->setImplicit(); + switch (op) { +#define ELEMENTARY_FUNCTION(ID, NAME) \ + case ID: \ + operatorDecl->setBodySynthesizer(deriveBodyElementaryFunctions_##ID, \ + nullptr); \ + break; +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION + } + if (auto env = parentDC->getGenericEnvironmentOfContext()) + operatorDecl->setGenericEnvironment(env); + operatorDecl->computeType(); + operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); + operatorDecl->setValidationToChecked(); + + derived.addMembersToConformanceContext({operatorDecl}); + C.addSynthesizedDecl(operatorDecl); + + return operatorDecl; +} + +ValueDecl * +DerivedConformance::deriveElementaryFunctions(ValueDecl *requirement) { + // Diagnose conformances in disallowed contexts. + if (checkAndDiagnoseDisallowedContext(requirement)) + return nullptr; + // Create memberwise initializer for nominal type if it doesn't already exist. + getOrCreateEffectiveMemberwiseInitializer(TC, Nominal); +#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ + if (requirement->getBaseName() == TC.Context.getIdentifier(NAME)) \ + return deriveElementaryFunction(*this, ID); +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION_UNARY + if (requirement->getBaseName() == TC.Context.getIdentifier("root")) + return deriveElementaryFunction(*this, Root); + if (requirement->getBaseName() == TC.Context.getIdentifier("pow")) { + auto *powFuncDecl = cast<FuncDecl>(requirement); + return powFuncDecl->getParameters()->get(1)->getName().str() == "n" + ? deriveElementaryFunction(*this, PowInt) + : deriveElementaryFunction(*this, Pow); + } + TC.diagnose(requirement->getLoc(), + diag::broken_elementary_functions_requirement); + return nullptr; +} diff --git a/lib/Sema/DerivedConformanceElementaryFunctions.def b/lib/Sema/DerivedConformanceElementaryFunctions.def new file mode 100644 index 0000000000000..692ad2e0bcd15 --- /dev/null +++ b/lib/Sema/DerivedConformanceElementaryFunctions.def @@ -0,0 +1,63 @@ +//===--- DerivedConformanceElementaryFunctions.def ------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file defines macros used for macro-metaprogramming with +// ElementaryFunction protocol requirements. Currently used only by derived +// conformances. +// +//===----------------------------------------------------------------------===// + +/// ELEMENTARY_FUNCTION(Id, Name) +/// - Id is an elementary function identifier, used for the enum case +/// `ElementaryFunctions::Id`. +/// - Name is the name of the elementary function. + +// One macro must be defined by the includer. +#if !defined(ELEMENTARY_FUNCTION) && !defined(ELEMENTARY_FUNCTION_UNARY) +#error "Macro must be defined by includer" +#endif + +#ifndef ELEMENTARY_FUNCTION +#define ELEMENTARY_FUNCTION(Id, Name) +#endif + +#ifndef ELEMENTARY_FUNCTION_UNARY +#define ELEMENTARY_FUNCTION_UNARY(Id, Name) ELEMENTARY_FUNCTION(Id,Name) +#endif + +ELEMENTARY_FUNCTION_UNARY(Sqrt, "sqrt") +ELEMENTARY_FUNCTION_UNARY(Cos, "cos") +ELEMENTARY_FUNCTION_UNARY(Sin, "sin") +ELEMENTARY_FUNCTION_UNARY(Tan, "tan") +ELEMENTARY_FUNCTION_UNARY(Cosh, "cosh") +ELEMENTARY_FUNCTION_UNARY(Sinh, "sinh") +ELEMENTARY_FUNCTION_UNARY(Tanh, "tanh") +ELEMENTARY_FUNCTION_UNARY(Acos, "acos") +ELEMENTARY_FUNCTION_UNARY(Asin, "asin") +ELEMENTARY_FUNCTION_UNARY(Atan, "atan") +ELEMENTARY_FUNCTION_UNARY(Acosh, "acosh") +ELEMENTARY_FUNCTION_UNARY(Asinh, "asinh") +ELEMENTARY_FUNCTION_UNARY(Atanh, "atanh") +ELEMENTARY_FUNCTION_UNARY(Exp, "exp") +ELEMENTARY_FUNCTION_UNARY(Exp2, "exp2") +ELEMENTARY_FUNCTION_UNARY(Exp10, "exp10") +ELEMENTARY_FUNCTION_UNARY(Expm1, "expm1") +ELEMENTARY_FUNCTION_UNARY(Log, "log") +ELEMENTARY_FUNCTION_UNARY(Log2, "log2") +ELEMENTARY_FUNCTION_UNARY(Log10, "log10") +ELEMENTARY_FUNCTION_UNARY(Log1p, "log1p") +ELEMENTARY_FUNCTION(Pow, "pow") +ELEMENTARY_FUNCTION(PowInt, "pow") +ELEMENTARY_FUNCTION(Root, "root") + +#undef ELEMENTARY_FUNCTION_UNARY +#undef ELEMENTARY_FUNCTION diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 8bca52eaaf71e..6ab43cec8a4a7 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -62,6 +62,14 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, return canDeriveHashable(Nominal); } + // SWIFT_ENABLE_TENSORFLOW + if (*knownProtocol == KnownProtocolKind::AdditiveArithmetic) + return canDeriveAdditiveArithmetic(Nominal, DC); + + // SWIFT_ENABLE_TENSORFLOW + if (*knownProtocol == KnownProtocolKind::ElementaryFunctions) + return canDeriveElementaryFunctions(Nominal, DC); + // SWIFT_ENABLE_TENSORFLOW if (*knownProtocol == KnownProtocolKind::KeyPathIterable) return canDeriveKeyPathIterable(Nominal); @@ -73,10 +81,6 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, // SWIFT_ENABLE_TENSORFLOW if (*knownProtocol == KnownProtocolKind::TensorGroup) return canDeriveTensorGroup(Nominal, DC); - - // SWIFT_ENABLE_TENSORFLOW - if (*knownProtocol == KnownProtocolKind::AdditiveArithmetic) - return canDeriveAdditiveArithmetic(Nominal, DC); // SWIFT_ENABLE_TENSORFLOW if (*knownProtocol == KnownProtocolKind::VectorProtocol) @@ -187,7 +191,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // Local function that retrieves the requirement with the same name as // the provided requirement, but within the given known protocol. - auto getRequirement = [&](KnownProtocolKind kind) -> ValueDecl * { + // SWIFT_ENABLE_TENSORFLOW + auto getRequirement = [&](KnownProtocolKind kind, + llvm::function_ref<bool(ValueDecl *)> filter = + nullptr) -> ValueDecl * { // Dig out the protocol. auto proto = ctx.getProtocol(kind); if (!proto) return nullptr; @@ -203,6 +210,14 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // Retrieve the requirement. auto results = proto->lookupDirect(name); + // SWIFT_ENABLE_TENSORFLOW + // Filter requirements, if `filter` function is specified. + if (filter) { + llvm::erase_if(results, [&](ValueDecl *v) { + return !isa<ProtocolDecl>(v->getDeclContext()) || + !v->isProtocolRequirement() || !filter(v); + }); + } return results.empty() ? nullptr : results.front(); }; @@ -294,6 +309,34 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, return getRequirement(KnownProtocolKind::AdditiveArithmetic); } + // SWIFT_ENABLE_TENSORFLOW + // ElementaryFunctions requirements + if (name.isCompoundName()) { + auto argumentNames = name.getArgumentNames(); + if (argumentNames.size() == 1 && (false +#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) || name.getBaseName() == NAME +#include "DerivedConformanceElementaryFunctions.def" +#undef ELEMENTARY_FUNCTION_UNARY + )) { + return getRequirement(KnownProtocolKind::ElementaryFunctions); + } + if (argumentNames.size() == 2) { + if (name.getBaseName() == "root") + return getRequirement(KnownProtocolKind::ElementaryFunctions); + if (name.getBaseName() == "pow") { + return getRequirement( + KnownProtocolKind::ElementaryFunctions, + [&](ValueDecl *v) { + auto *funcDecl = dyn_cast<FuncDecl>(v); + if (!funcDecl) + return false; + return funcDecl->getParameters()->get(1)->getName() == + func->getParameters()->get(1)->getName(); + }); + } + } + } + // SWIFT_ENABLE_TENSORFLOW // VectorProtocol.scaled(by:) if (name.isCompoundName() && name.getBaseName() == ctx.Id_scaled) { diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index fb9822962ccd4..04e52aef87720 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -243,6 +243,19 @@ class DerivedConformance { /// \returns the derived member, which will also be added to the type. ValueDecl *deriveAdditiveArithmetic(ValueDecl *requirement); + // SWIFT_ENABLE_TENSORFLOW + /// Determine if an ElementaryFunctions requirement can be derived for a + /// type. + /// + /// \returns True if the requirement can be derived. + static bool canDeriveElementaryFunctions(NominalTypeDecl *type, + DeclContext *DC); + + /// Derive an ElementaryFunctions requirement for a nominal type. + /// + /// \returns the derived member, which will also be added to the type. + ValueDecl *deriveElementaryFunctions(ValueDecl *requirement); + /// Determine if a VectorProtocol requirement can be derived for a type. /// /// \returns True if the requirement can be derived. diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 207d0450cde88..0a3c9036e573e 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -5346,6 +5346,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC, case KnownProtocolKind::AdditiveArithmetic: return derived.deriveAdditiveArithmetic(Requirement); + // SWIFT_ENABLE_TENSORFLOW + case KnownProtocolKind::ElementaryFunctions: + return derived.deriveElementaryFunctions(Requirement); + // SWIFT_ENABLE_TENSORFLOW case KnownProtocolKind::VectorProtocol: return derived.deriveVectorProtocol(Requirement); diff --git a/stdlib/public/core/MathFunctions.swift.gyb b/stdlib/public/core/MathFunctions.swift.gyb index d2244b20c1627..2af6c5d03a81c 100644 --- a/stdlib/public/core/MathFunctions.swift.gyb +++ b/stdlib/public/core/MathFunctions.swift.gyb @@ -34,7 +34,8 @@ import SwiftShims /// ElementaryFunctions and FloatingPoint. /// /// [elfn]: http://en.wikipedia.org/wiki/Elementary_function -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) +// SWIFT_ENABLE_TENSORFLOW +// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) public protocol ElementaryFunctions { %for func in ElementaryFunctions: @@ -125,7 +126,8 @@ extension ${Self}: ElementaryFunctions { % end %end -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) +// SWIFT_ENABLE_TENSORFLOW +// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) extension SIMD where Scalar: ElementaryFunctions { % for func in ElementaryFunctions: @@ -168,6 +170,7 @@ extension SIMD where Scalar: ElementaryFunctions { } %for n in [2,3,4,8,16,32,64]: -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) +// SWIFT_ENABLE_TENSORFLOW +// @available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) extension SIMD${n}: ElementaryFunctions where Scalar: ElementaryFunctions { } %end diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index e0450010c2363..568a692b27612 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -40,7 +40,7 @@ struct TestNoDerivative : Differentiable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorProtocol +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol // CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables @@ -54,7 +54,7 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorProtocol +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, ElementaryFunctions, VectorProtocol // CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables diff --git a/test/Sema/Inputs/struct_elementary_functions_other_module.swift b/test/Sema/Inputs/struct_elementary_functions_other_module.swift new file mode 100644 index 0000000000000..6d9c65ea5091a --- /dev/null +++ b/test/Sema/Inputs/struct_elementary_functions_other_module.swift @@ -0,0 +1,14 @@ +// SWIFT_ENABLE_TENSORFLOW + +// expected-note @+1 24 {{type declared here}} +struct OtherFileNonconforming : Equatable { + let float: Float + var double: Double +} + +// expected-note @+1 24 {{type declared here}} +struct GenericOtherFileNonconforming<T : ElementaryFunctions> : Equatable { + let x: T + var float: Float + var double: Double +} diff --git a/test/Sema/struct_differentiable.swift b/test/Sema/struct_differentiable.swift index 5ab2573b0e11e..7986362380153 100644 --- a/test/Sema/struct_differentiable.swift +++ b/test/Sema/struct_differentiable.swift @@ -8,12 +8,18 @@ func assertAllDifferentiableVariablesEqualsTangentVector<T>(_: T.Type) // Verify that a type `T` conforms to `AdditiveArithmetic`. func assertConformsToAdditiveArithmetic<T>(_: T.Type) where T : AdditiveArithmetic {} +// Verify that a type `T` conforms to `ElementaryFunctions`. +func assertConformsToElementaryFunctions<T>(_: T.Type) where T : ElementaryFunctions {} + // Verify that a type `T` conforms to `VectorProtocol`. func assertConformsToVectorProtocol<T>(_: T.Type) where T : VectorProtocol {} struct Empty : Differentiable {} func testEmpty() { assertConformsToAdditiveArithmetic(Empty.AllDifferentiableVariables.self) + assertConformsToAdditiveArithmetic(Empty.TangentVector.self) + assertConformsToElementaryFunctions(Empty.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(Empty.TangentVector.self) } // Test interaction with `AdditiveArithmetic` derived conformances. @@ -130,19 +136,35 @@ func testAllMembersAdditiveArithmetic() { } // Test type `AllMembersVectorProtocol` whose members conforms to `VectorProtocol`, -// in which case we should make `TangentVector` and `TangentVector` conform to -// `VectorProtocol`. +// in which case we should make `AllDifferentiableVariables` and `TangentVector` +// conform to `VectorProtocol`. struct MyVector : VectorProtocol, Differentiable { var w: Float var b: Float } struct AllMembersVectorProtocol : Differentiable { - var w: MyVector - var b: MyVector + var v1: MyVector + var v2: MyVector } func testAllMembersVectorProtocol() { + assertConformsToVectorProtocol(AllMembersVectorProtocol.AllDifferentiableVariables.self) assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self) - assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self) +} + +// Test type `AllMembersElementaryFunctions` whose members conforms to `ElementaryFunctions`, +// in which case we should make `AllDifferentiableVariables` and `TangentVector` +// conform to `ElementaryFunctions`. +struct MyVector2 : ElementaryFunctions, Differentiable { + var w: Float + var b: Float +} +struct AllMembersElementaryFunctions : Differentiable { + var v1: MyVector2 + var v2: MyVector2 +} +func testAllMembersElementaryFunctions() { + assertConformsToElementaryFunctions(AllMembersElementaryFunctions.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(AllMembersElementaryFunctions.TangentVector.self) } // Test type whose properties are not all differentiable. @@ -154,6 +176,7 @@ struct DifferentiableSubset : Differentiable { } func testDifferentiableSubset() { assertConformsToAdditiveArithmetic(DifferentiableSubset.AllDifferentiableVariables.self) + assertConformsToElementaryFunctions(DifferentiableSubset.AllDifferentiableVariables.self) assertConformsToVectorProtocol(DifferentiableSubset.AllDifferentiableVariables.self) assertAllDifferentiableVariablesEqualsTangentVector(DifferentiableSubset.self) _ = DifferentiableSubset.TangentVector(w: 1, b: 1) diff --git a/test/Sema/struct_elementary_functions.swift b/test/Sema/struct_elementary_functions.swift new file mode 100644 index 0000000000000..1c0bbe1ba883d --- /dev/null +++ b/test/Sema/struct_elementary_functions.swift @@ -0,0 +1,91 @@ +// SWIFT_ENABLE_TENSORFLOW +// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_elementary_functions_other_module.swift + +struct Empty : ElementaryFunctions {} +func testEmpty() { + _ = Empty() +} + +struct Float2: ElementaryFunctions { + let a: Float + var b: Float +} +func testFloat2() { + _ = Float2(a: 1, b: 1) +} + +// Test generic type. +struct Vector2<T : ElementaryFunctions>: ElementaryFunctions { + let x: T + var y: T +} +func testVector2() { + _ = Vector2<Double>(x: 1, y: 1) +} + +// Test nested type. +struct Nested: ElementaryFunctions { + let float2: Float2 + var float: Float +} +func testNested(float2: Float2) { + _ = Nested(float2: float2, float: 1) +} + +// Test mixed type. +struct Mixed: ElementaryFunctions { + let nested: Nested + var float = Float(1) + var double: Double +} +func testMixed(nested: Nested) { + _ = Mixed(nested: nested, float: 1, double: 1) +} + +// Test type in generic context. +struct A<T> { + struct B<U, V> { + struct GenericContextNested : ElementaryFunctions { + var nested: Nested + let float: Float + var double = Double(2) + } + } +} +func testGenericContext<T, U, V>(nested: Nested) -> A<T>.B<U, V>.GenericContextNested { + A<T>.B<U, V>.GenericContextNested(nested: nested, float: 1, double: 1) +} + +// Test extension. +struct Extended { + var x: Float +} +extension Extended : ElementaryFunctions {} + +// Test extension of generic type. +struct GenericExtended<T> { + var x: T +} +extension GenericExtended : ElementaryFunctions where T : ElementaryFunctions {} + +// Test memberwise initializer synthesis. +struct NoMemberwiseInitializer<T : ElementaryFunctions> : ElementaryFunctions { + var value: T + init(randomLabel value: T) { self.value = value } +} +struct NoMemberwiseInitializerExtended<T> { + var value: T + init(_ value: T) { + self.value = value + } +} +extension NoMemberwiseInitializerExtended: ElementaryFunctions + where T : ElementaryFunctions {} + +// Test derived conformances in disallowed contexts. + +// expected-error @+1 24 {{implementation of 'ElementaryFunctions' cannot be automatically synthesized in an extension in a different file to the type}} +extension OtherFileNonconforming : ElementaryFunctions {} + +// expected-error @+1 24 {{implementation of 'ElementaryFunctions' cannot be automatically synthesized in an extension in a different file to the type}} +extension GenericOtherFileNonconforming : ElementaryFunctions {} diff --git a/test/stdlib/ElementaryFunctions.swift.gyb b/test/stdlib/ElementaryFunctions.swift.gyb new file mode 100644 index 0000000000000..ccc9710f78cff --- /dev/null +++ b/test/stdlib/ElementaryFunctions.swift.gyb @@ -0,0 +1,102 @@ +//===--- ElementaryFunctions.swift.gyb ------------------------*- swift -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// SWIFT_ENABLE_TENSORFLOW +// Runtime tests for ElementaryFunctions derived conformances. +//===----------------------------------------------------------------------===// +// -*- swift -*- +// RUN: %empty-directory(%t) +// RUN: %gyb %s -o %t/tgmath.swift +// RUN: %line-directive %t/tgmath.swift -- %target-build-swift %t/tgmath.swift -o %t/a.out +// RUN: %target-codesign %t/a.out +// RUN: %line-directive %t/tgmath.swift -- %target-run %t/a.out +// REQUIRES: executable_test + +#if (arch(i386) || arch(x86_64)) && !os(Windows) + typealias TestLiteralType = Float80 +#else + typealias TestLiteralType = Double +#endif + +import StdlibUnittest + +let MathTests = TestSuite("Math") + +func expectEqualWithNaNEquality<T>( + _ expected: [T], _ actual: [T], file: String = #file, line: UInt = #line +) where T: BinaryFloatingPoint { + for (x, y) in zip(expected, actual) { + expectTrue(x == y || x.isNaN && y.isNaN, + "\(x) != \(y) for \(T.self).", + file: file, line: line) + } +} + +%from SwiftMathFunctions import * + +struct Wrapper<T: ElementaryFunctions & Equatable>: ElementaryFunctions & Equatable { + var x, y: T + var values: [T] { [x, y] } +} + +@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) +internal extension ElementaryFunctions where Self: BinaryFloatingPoint { + static func elementaryFunctionTests() { + let values: [Self] = [-0.375, 0.375] + let wrapper = Wrapper<Self>(x: values[0], y: values[1]) + + expectEqualWithNaNEquality(values.map(Self.acos), Wrapper<Self>.acos(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.asin), Wrapper<Self>.asin(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.atan), Wrapper<Self>.atan(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.cos), Wrapper<Self>.cos(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.sin), Wrapper<Self>.sin(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.tan), Wrapper<Self>.tan(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.acosh), Wrapper<Self>.acosh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.asinh), Wrapper<Self>.asinh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.atanh), Wrapper<Self>.atanh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.cosh), Wrapper<Self>.cosh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.sinh), Wrapper<Self>.sinh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.tanh), Wrapper<Self>.tanh(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.exp), Wrapper<Self>.exp(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.exp2), Wrapper<Self>.exp2(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.exp10), Wrapper<Self>.exp10(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.expm1), Wrapper<Self>.expm1(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.log), Wrapper<Self>.log(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.log2), Wrapper<Self>.log2(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.log10), Wrapper<Self>.log10(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.log1p), Wrapper<Self>.log1p(wrapper).values) + expectEqualWithNaNEquality(values.map(Self.sqrt), Wrapper<Self>.sqrt(wrapper).values) + expectEqualWithNaNEquality(values.map { x in Self.root(x, 3) }, Wrapper<Self>.root(wrapper, 3).values) + expectEqualWithNaNEquality(values.map { x in Self.pow(x, x) }, Wrapper<Self>.pow(wrapper, wrapper).values) + expectEqualWithNaNEquality(values.map { x in Self.pow(x, 3) }, Wrapper<Self>.pow(wrapper, 3).values) + } +} + +%for T in ['Float', 'Double', 'CGFloat', 'Float80']: +% if T == 'Float80': +#if (arch(i386) || arch(x86_64)) && !os(Windows) +% elif T == 'CGFloat': +#if canImport(CoreGraphics) + import CoreGraphics +% end + +MathTests.test("${T}") { + if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + ${T}.elementaryFunctionTests() + } +} + +% if T in ['CGFloat', 'Float80']: +#endif +% end +%end + +runAllTests()