diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 2cef56e9e9ed3..abd0d72a2c5a5 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor); /// Similar to `getDistributedSerializationRequirementType`, however, from the /// perspective of a concrete function. This way we're able to get the /// serialization requirement for specific members, also in protocols. -Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member); +Type getSerializationRequirementTypesForMember( + ValueDecl *member, llvm::SmallPtrSet &serializationRequirements); /// Get specific 'SerializationRequirement' as defined in 'nominal' /// type, which must conform to the passed 'protocol' which is expected @@ -91,19 +92,13 @@ llvm::SmallPtrSet getDistributedSerializationRequirementProtocols( NominalTypeDecl *decl, ProtocolDecl* protocol); -/// Desugar and flatten the `SerializationRequirement` type into a set of -/// specific protocol declarations. -llvm::SmallPtrSet -flattenDistributedSerializationTypeToRequiredProtocols( - TypeBase *serializationRequirement); - /// Check if the `allRequirements` represent *exactly* the /// `Encodable & Decodable` (also known as `Codable`) requirement. /// /// If so, we can emit slightly nicer diagnostics. bool checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements); + Type type); /// Get the `SerializationRequirement`, explode it into the specific /// protocol requirements and insert them into `requirements`. @@ -120,15 +115,6 @@ getDistributedSerializationRequirements( ProtocolDecl *protocol, llvm::SmallPtrSetImpl &requirementProtos); -/// Given any set of generic requirements, locate those which are about the -/// `SerializationRequirement`. Those need to be applied in the parameter and -/// return type checking of distributed targets. -llvm::SmallPtrSet -extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements); - } -// ==== ------------------------------------------------------------------------ - #endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */ diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 2a896409226b4..7d3c457e148a2 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member) llvm_unreachable("Unable to fetch ActorSystem type!"); } -Type swift::getConcreteReplacementForMemberSerializationRequirement( - ValueDecl *member) { +Type swift::getSerializationRequirementTypesForMember( + ValueDecl *member, + llvm::SmallPtrSet &serializationRequirements) { auto &C = member->getASTContext(); auto *DC = member->getDeclContext(); auto DA = C.getDistributedActorDecl(); @@ -106,8 +107,10 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl()); } - /// === Maybe the value is declared in a protocol? - if (auto protocol = DC->getSelfProtocolDecl()) { + auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) + ->getDeclaredInterfaceType(); + + if (DC->getSelfProtocolDecl()) { GenericSignature signature; if (auto *genericContext = member->getAsGenericContext()) { signature = genericContext->getGenericSignature(); @@ -115,8 +118,10 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } - auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) - ->getDeclaredInterfaceType(); + // Also store all `SerializationRequirement : SomeProtocol` requirements + for (auto proto: signature->getRequiredProtocols(SerReqAssocType)) { + serializationRequirements.insert(proto); + } // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. @@ -178,13 +183,7 @@ Type swift::getDistributedActorSystemResultHandlerType( auto module = system->getParentModule(); Type selfType = system->getSelfInterfaceType(); auto conformance = module->lookupConformance(selfType, DAS); - auto witness = - conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler); - if (auto alias = dyn_cast(witness.getPointer())) { - return alias->getDecl()->getUnderlyingType(); - } else { - return witness; - } + return conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler); } Type swift::getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system) { @@ -346,63 +345,39 @@ swift::getDistributedSerializationRequirements( if (existentialRequirementTy->isAny()) return true; // we're done here, any means there are no requirements - if (!existentialRequirementTy->isExistentialType()) { - // SerializationRequirement must be an existential type - return false; - } - - ExistentialType *serialReqType = existentialRequirementTy - ->castTo(); + auto *serialReqType = existentialRequirementTy->getAs(); if (!serialReqType || serialReqType->hasError()) { return false; } - auto desugaredTy = serialReqType->getConstraintType()->getDesugaredType(); - auto flattenedRequirements = - flattenDistributedSerializationTypeToRequiredProtocols( - desugaredTy); - for (auto p : flattenedRequirements) { + auto layout = serialReqType->getExistentialLayout(); + for (auto p : layout.getProtocols()) { requirementProtos.insert(p); } return true; } -llvm::SmallPtrSet -swift::flattenDistributedSerializationTypeToRequiredProtocols( - TypeBase *serializationRequirement) { - llvm::SmallPtrSet serializationReqs; - if (auto composition = - serializationRequirement->getAs()) { - for (auto member : composition->getMembers()) { - if (auto comp = member->getAs()) { - for (auto protocol : - flattenDistributedSerializationTypeToRequiredProtocols(comp)) { - serializationReqs.insert(protocol); - } - } else if (auto *protocol = member->getAs()) { - serializationReqs.insert(protocol->getDecl()); - } - } - } else { - auto protocol = serializationRequirement->castTo()->getDecl(); - serializationReqs.insert(protocol); - } - - return serializationReqs; -} - bool swift::checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements) { + Type type) { + if (!type) + return false; + + if (type->hasError()) + return false; + auto encodable = C.getProtocol(KnownProtocolKind::Encodable); auto decodable = C.getProtocol(KnownProtocolKind::Decodable); - if (allRequirements.size() != 2) + auto layout = type->getExistentialLayout(); + auto protocols = layout.getProtocols(); + + if (protocols.size() != 2) return false; - return allRequirements.count(encodable) && - allRequirements.count(decodable); + return std::count(protocols.begin(), protocols.end(), encodable) == 1 && + std::count(protocols.begin(), protocols.end(), decodable) == 1; } /******************************************************************************/ @@ -571,25 +546,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) // --- Check requirement: conforms_to: Act DistributedActor auto actorReq = requirements[0]; - auto distActorTy = C.getProtocol(KnownProtocolKind::DistributedActor) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (actorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!actorReq.getSecondType()->isEqual(distActorTy)) { + if (!actorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::DistributedActor)) { return false; } // --- Check requirement: conforms_to: Err Error auto errorReq = requirements[1]; - auto errorTy = C.getProtocol(KnownProtocolKind::Error) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (errorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!errorReq.getSecondType()->isEqual(errorTy)) { + if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) { return false; } @@ -604,10 +573,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) assert(ResParam && "Non void function, yet no Res generic parameter found"); if (auto func = dyn_cast(this)) { auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ResParam->getInterfaceType()->getMetatypeInstanceType()); + ResParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { return false; @@ -803,12 +771,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const // the of the RemoteCallArgument auto remoteCallArgValueGenericTy = - mapTypeIntoContext(argGenericParams[0]->getInterfaceType()) - ->getDesugaredType() - ->getMetatypeInstanceType(); + mapTypeIntoContext(argGenericParams[0]->getDeclaredInterfaceType()); // expected (the from the recordArgument) auto expectedGenericParamTy = mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) { return false; @@ -938,11 +904,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con // ... auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { @@ -1052,13 +1017,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons // --- Check requirement: conforms_to: Err Error auto errorReq = requirements[0]; - auto errorTy = C.getProtocol(KnownProtocolKind::Error) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (errorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!errorReq.getSecondType()->isEqual(errorTy)) { + if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) { return false; } @@ -1145,10 +1107,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c // --- Check: Argument: SerializationRequirement GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0]; auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { return false; @@ -1243,11 +1204,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const // === Check generic parameters in detail // --- Check: Argument: SerializationRequirement GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0]; - auto argumentType = func->mapTypeIntoContext(valueParam->getInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + auto argumentType = func->mapTypeIntoContext( + valueParam->getInterfaceType()->getMetatypeInstanceType()); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!argumentType->isEqual(resultParamType)) { return false; @@ -1268,50 +1228,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -llvm::SmallPtrSet -swift::extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements) { - llvm::SmallPtrSet serializationReqs; - auto DA = C.getDistributedActorDecl(); - auto daSerializationReqAssocType = - DA->getAssociatedType(C.Id_SerializationRequirement); - auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType(); - - for (auto req : allRequirements) { - if (req.getSecondType()->isAny()) { - continue; - } - if (!req.getFirstType()->hasDependentMember()) - continue; - - if (auto dependentMemberType = - req.getFirstType()->castTo()) { - auto dependentTy = - dependentMemberType->getAssocType()->getInterfaceType(); - - if (dependentTy->isEqual(daSystemSerializationReqTy)) { - auto requirementProto = req.getSecondType(); - if (auto proto = dyn_cast_or_null( - requirementProto->getAnyNominal())) { - serializationReqs.insert(proto); - } else { - auto serialReqType = requirementProto->castTo() - ->getConstraintType() - ->getDesugaredType(); - auto flattenedRequirements = - flattenDistributedSerializationTypeToRequiredProtocols( - serialReqType); - for (auto p : flattenedRequirements) { - serializationReqs.insert(p); - } - } - } - } - } - - return serializationReqs; -} - /******************************************************************************/ /********************** Distributed Functions *********************************/ /******************************************************************************/ diff --git a/lib/Sema/CodeSynthesisDistributedActor.cpp b/lib/Sema/CodeSynthesisDistributedActor.cpp index 66b1ead965f2c..16ff8a8b0af23 100644 --- a/lib/Sema/CodeSynthesisDistributedActor.cpp +++ b/lib/Sema/CodeSynthesisDistributedActor.cpp @@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator, if (!distributedTarget->isDistributed()) return nullptr; } - assert(distributedTarget); + // This evaluation type-check by now was already computed and cached; + // We need to check in order to avoid emitting a THUNK for a distributed func + // which had errors; as the thunk then may also cause un-addressable issues and confusion. + if (swift::checkDistributedFunction(distributedTarget)) { + return nullptr; + } + auto &C = distributedTarget->getASTContext(); if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 9aa62497aecb4..a6790c9253c3a 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -2067,7 +2067,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) { return (prop && prop->isFinal() && isa(prop->getDeclContext()) && - cast(prop->getDeclContext())->isActor() && + cast(prop->getDeclContext())->isAnyActor() && !prop->isStatic() && prop->getName() == ctx.Id_unownedExecutor && prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl()); diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 9623da08fd84b..e4b6624856510 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -385,10 +385,18 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, - const llvm::SmallPtrSetImpl &serializationRequirements, + Type serializationRequirement, + llvm::SmallPtrSet serializationRequirements, bool diagnose) { auto &C = valueDecl->getASTContext(); + if (serializationRequirement && serializationRequirement->hasError()) { + return false; + } + if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) { + return false; // error of the type would be diagnosed elsewhere + } + Type resultType; if (auto func = dyn_cast(valueDecl)) { resultType = func->mapTypeIntoContext(func->getResultInterfaceType()); @@ -401,18 +409,27 @@ static bool checkDistributedTargetResultType( if (resultType->isVoid()) return false; + + // Collect extra "SerializationRequirement: SomeProtocol" requirements + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto s: srl.getProtocols()) { + serializationRequirements.insert(s); + } + } + auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); + C, serializationRequirement); - for(auto serializationReq : serializationRequirements) { + for (auto serializationReq: serializationRequirements) { auto conformance = TypeChecker::conformsToProtocol(resultType, serializationReq, module); if (conformance.isInvalid()) { if (diagnose) { llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); auto diag = valueDecl->diagnose( diag::distributed_actor_target_result_not_codable, @@ -427,12 +444,12 @@ static bool checkDistributedTargetResultType( } } } // end if: diagnose - + return true; } } - return false; + return false; } bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) { @@ -496,66 +513,42 @@ bool CheckDistributedFunctionRequest::evaluate( } auto &C = func->getASTContext(); - auto DC = func->getDeclContext(); auto module = func->getParentModule(); /// If no distributed module is available, then no reason to even try checks. if (!C.getLoadedModule(C.Id_Distributed)) return true; - // === All parameters and the result type must conform - // SerializationRequirement llvm::SmallPtrSet serializationRequirements; - if (auto extension = dyn_cast(DC)) { - serializationRequirements = extractDistributedSerializationRequirements( - C, extension->getGenericRequirements()); - } else if (auto actor = dyn_cast(DC)) { - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (isa(DC)) { - if (auto seqReqTy = - getConcreteReplacementForMemberSerializationRequirement(func)) { - auto seqReqTyDes = seqReqTy->castTo()->getConstraintType()->getDesugaredType(); - for (auto req : flattenDistributedSerializationTypeToRequiredProtocols(seqReqTyDes)) { - serializationRequirements.insert(req); - } - } - - // The distributed actor constrained protocol has no serialization requirements - // or actor system defined, so these will only be enforced, by implementations - // of DAs conforming to it, skip checks here. - if (serializationRequirements.empty()) { - return false; - } - } else { - llvm_unreachable("Distributed function detected in type other than extension, " - "distributed actor, or protocol! This should not be possible " - ", please file a bug."); - } - - // If the requirement is exactly `Codable` we diagnose it ia bit nicer. - auto serializationRequirementIsCodable = - checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for (auto param : *func->getParameters()) { - // --- Check parameters for 'Codable' conformance - auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); - - for (auto req : serializationRequirements) { - if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { - auto diag = func->diagnose( - diag::distributed_actor_func_param_not_codable, - param->getArgumentName().str(), param->getInterfaceType(), - func->getDescriptiveKind(), - serializationRequirementIsCodable ? "Codable" - : req->getNameStr()); - - if (auto paramNominalTy = paramTy->getAnyNominal()) { - addCodableFixIt(paramNominalTy, diag); - } // else, no nominal type to suggest the fixit for, e.g. a closure - return true; + Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements); + + for (auto param: *func->getParameters()) { + // --- Check the parameter conforming to serialization requirements + if (serializationReqType && !serializationReqType->hasError()) { + // If the requirement is exactly `Codable` we diagnose it ia bit nicer. + auto serializationRequirementIsCodable = + checkDistributedSerializationRequirementIsExactlyCodable( + C, serializationReqType); + + // --- Check parameters for 'SerializationRequirement' conformance + auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); + + auto srl = serializationReqType->getExistentialLayout(); + for (auto req: srl.getProtocols()) { + if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { + auto diag = func->diagnose( + diag::distributed_actor_func_param_not_codable, + param->getArgumentName().str(), param->getInterfaceType(), + func->getDescriptiveKind(), + serializationRequirementIsCodable ? "Codable" + : req->getNameStr()); + + if (auto paramNominalTy = paramTy->getAnyNominal()) { + addCodableFixIt(paramNominalTy, diag); + } // else, no nominal type to suggest the fixit for, e.g. a closure + + return true; + } } } @@ -592,9 +585,10 @@ bool CheckDistributedFunctionRequest::evaluate( } } - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationRequirements, - /*diagnose=*/true)) { + // --- Result type must be either void or a serialization requirement conforming type + if (checkDistributedTargetResultType( + module, func, serializationReqType, serializationRequirements, + /*diagnose=*/true)) { return true; } @@ -648,8 +642,11 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { systemDecl, C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = + getSerializationRequirementTypesForMember(systemVar, serializationRequirements); + auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) { return true; } @@ -749,13 +746,14 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal (void)nominal->getDistributedActorIDProperty(); } -void TypeChecker::checkDistributedFunc(FuncDecl *func) { +bool TypeChecker::checkDistributedFunc(FuncDecl *func) { if (!func->isDistributed()) - return; + return false; - swift::checkDistributedFunction(func); + return swift::checkDistributedFunction(func); } +// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks llvm::SmallPtrSet swift::getDistributedSerializationRequirementProtocols( NominalTypeDecl *nominal, ProtocolDecl *protocol) { @@ -768,11 +766,13 @@ swift::getDistributedSerializationRequirementProtocols( return {}; } - auto serialReqType = - ty->castTo()->getConstraintType()->getDesugaredType(); - // TODO(distributed): check what happens with Any - return flattenDistributedSerializationTypeToRequiredProtocols(serialReqType); + auto layout = ty->getExistentialLayout(); + llvm::SmallPtrSet result; + for (auto p : layout.getProtocols()) { + result.insert(p); + } + return result; } ConstructorDecl* @@ -896,8 +896,7 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator, continue; auto paramTy = genericParamList->getParams()[0] - ->getInterfaceType() - ->getMetatypeInstanceType(); + ->getDeclaredInterfaceType(); // `decodeNextArgument` should return its generic parameter value if (!FD->getResultInterfaceType()->isEqual(paramTy)) @@ -905,20 +904,16 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator, // Let's find out how many serialization requirements does this method cover // e.g. `Codable` is two requirements - `Encodable` and `Decodable`. - unsigned numSerializationReqsCovered = llvm::count_if( - FD->getGenericRequirements(), [&](const Requirement &requirement) { - if (!(requirement.getFirstType()->isEqual(paramTy) && - requirement.getKind() == RequirementKind::Conformance)) - return 0; - - return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0; - }); + bool okay = llvm::all_of(serializationReqs, + [&](ProtocolDecl *p) -> bool { + return FD->getGenericSignature()->requiresProtocol(paramTy, p); + }); // If the current method covers all of the serialization requirements, // it's a match. Note that it might also have other requirements, but // we let that go as long as there are no two candidates that differ // only in generic requirements. - if (numSerializationReqsCovered == serializationReqs.size()) + if (okay) candidates.push_back(FD); } diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index abde198ae90d8..54d46182b6605 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -2778,6 +2778,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval, // So, build out the body now. ASTScope::expandFunctionBody(AFD); + if (AFD->isDistributedThunk()) { + if (auto func = dyn_cast(AFD)) { + if (TypeChecker::checkDistributedFunc(func)) { + return errorBody(); + } + } + } + // Type check the function body if needed. bool hadError = false; if (!alreadyTypeChecked) { diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index fa77dae7e5967..5bff84aa1ad75 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1132,7 +1132,9 @@ diagnosePotentialUnavailability(SourceRange ReferenceRange, void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl); /// Type check a single 'distributed func' declaration. -void checkDistributedFunc(FuncDecl *func); +/// +/// Returns `true` if there was an error. +bool checkDistributedFunc(FuncDecl *func); bool checkAvailability(SourceRange ReferenceRange, AvailabilityContext Availability, diff --git a/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift new file mode 100644 index 0000000000000..785accbb2e58d --- /dev/null +++ b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift @@ -0,0 +1,41 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift +// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/Inputs/FakeDistributedActorSystems.swift 2> %t/output.txt || echo 'failed expectedly' +// RUN: %FileCheck %s < %t/output.txt + +// REQUIRES: concurrency +// REQUIRES: distributed + +// rdar://76038845 +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +import Distributed + +// Notes: +// This test specifically is not just a -typecheck -verify test but attempts to generate the whole module. +// This is because we may be emitting errors but otherwise still attempt to emit a thunk for an "error-ed" +// distributed function, which would then crash in later phases of compilation when we try to get types +// of the `func` the THUNK is based on. + +typealias DefaultDistributedActorSystem = LocalTestingDistributedActorSystem + +distributed actor Service { +} + +extension Service { + distributed func boombox(_ id: Box) async throws {} + // CHECK: parameter '' of type 'Box' in distributed instance method does not conform to serialization requirement 'Codable' + + distributed func boxIt() async throws -> Box { fatalError() } + // CHECK: result type 'Box' of distributed instance method 'boxIt' does not conform to serialization requirement 'Codable' +} + +public enum Box: Hashable { case boom } + +@main struct Main { + static func main() async { + try? await Service(actorSystem: .init()).boombox(Box.boom) + } +} + diff --git a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift index af4fa1020ce58..ee3ce1d570797 100644 --- a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift +++ b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift @@ -82,6 +82,13 @@ extension NoSerializationRequirementYet extension NoSerializationRequirementYet where SerializationRequirement: Codable { + // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}} + distributed func test4() -> NotCodable { + .init() + } +} + +extension ProtocolWithChecksSeqReqDA { // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}} distributed func test4() -> NotCodable { .init()