Skip to content

[SILGen] Fix the type of closure thunks that are passed const reference structs #82486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "swift/AST/FunctionRefInfo.h"
#include "swift/AST/ProtocolConformanceRef.h"
#include "swift/AST/ThrownErrorDestination.h"
#include "swift/AST/Type.h"
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Debug.h"
#include "swift/Basic/InlineBitfield.h"
Expand Down Expand Up @@ -3367,9 +3368,8 @@ class UnresolvedTypeConversionExpr : public ImplicitConversionExpr {
/// FIXME: This should be a CapturingExpr.
class FunctionConversionExpr : public ImplicitConversionExpr {
public:
FunctionConversionExpr(Expr *subExpr, Type type)
: ImplicitConversionExpr(ExprKind::FunctionConversion, subExpr, type) {}

FunctionConversionExpr(Expr *subExpr, Type type);

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::FunctionConversion;
}
Expand Down Expand Up @@ -4290,6 +4290,12 @@ class ClosureExpr : public AbstractClosureExpr {
/// The body of the closure.
BraceStmt *Body;

/// Used when lowering ClosureExprs to C function pointers.
/// This is required to access the ClangType from SILDeclRef.
/// TODO: this will be redundant after we preserve ClangTypes
/// in the canonical types.
FunctionConversionExpr *ConvertedTo;

friend class GlobalActorAttributeRequest;

bool hasNoGlobalActorAttribute() const {
Expand All @@ -4301,19 +4307,19 @@ class ClosureExpr : public AbstractClosureExpr {
}

public:
ClosureExpr(const DeclAttributes &attributes,
SourceRange bracketRange, VarDecl *capturedSelfDecl,
ParameterList *params, SourceLoc asyncLoc, SourceLoc throwsLoc,
TypeExpr *thrownType, SourceLoc arrowLoc, SourceLoc inLoc,
TypeExpr *explicitResultType, DeclContext *parent)
: AbstractClosureExpr(ExprKind::Closure, Type(), /*Implicit=*/false,
parent),
Attributes(attributes), BracketRange(bracketRange),
CapturedSelfDecl(capturedSelfDecl),
AsyncLoc(asyncLoc), ThrowsLoc(throwsLoc), ArrowLoc(arrowLoc),
InLoc(inLoc), ThrownType(thrownType),
ExplicitResultTypeAndBodyState(explicitResultType, BodyState::Parsed),
Body(nullptr) {
ClosureExpr(const DeclAttributes &attributes, SourceRange bracketRange,
VarDecl *capturedSelfDecl, ParameterList *params,
SourceLoc asyncLoc, SourceLoc throwsLoc, TypeExpr *thrownType,
SourceLoc arrowLoc, SourceLoc inLoc, TypeExpr *explicitResultType,
DeclContext *parent)
: AbstractClosureExpr(ExprKind::Closure, Type(), /*Implicit=*/false,
parent),
Attributes(attributes), BracketRange(bracketRange),
CapturedSelfDecl(capturedSelfDecl), AsyncLoc(asyncLoc),
ThrowsLoc(throwsLoc), ArrowLoc(arrowLoc), InLoc(inLoc),
ThrownType(thrownType),
ExplicitResultTypeAndBodyState(explicitResultType, BodyState::Parsed),
Body(nullptr), ConvertedTo(nullptr) {
setParameterList(params);
Bits.ClosureExpr.HasAnonymousClosureVars = false;
Bits.ClosureExpr.ImplicitSelfCapture = false;
Expand Down Expand Up @@ -4528,6 +4534,9 @@ class ClosureExpr : public AbstractClosureExpr {
ExplicitResultTypeAndBodyState.setInt(v);
}

const FunctionConversionExpr *getConvertedTo() const { return ConvertedTo; }
void setConvertedTo(FunctionConversionExpr *e) { ConvertedTo = e; }

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::Closure;
}
Expand Down
12 changes: 7 additions & 5 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ namespace llvm {
class raw_ostream;
}

namespace clang {
class Type;
}

namespace swift {
enum class EffectsKind : uint8_t;
class AbstractFunctionDecl;
Expand Down Expand Up @@ -261,11 +265,9 @@ struct SILDeclRef {
/// for the containing ClassDecl.
/// - If 'loc' is a global VarDecl, this returns its GlobalAccessor
/// SILDeclRef.
explicit SILDeclRef(
Loc loc,
bool isForeign = false,
bool isDistributed = false,
bool isDistributedLocal = false);
explicit SILDeclRef(Loc loc, bool isForeign = false,
bool isDistributed = false,
bool isDistributedLocal = false);

/// See above put produces a prespecialization according to the signature.
explicit SILDeclRef(Loc loc, GenericSignature prespecializationSig);
Expand Down
10 changes: 10 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,16 @@ DestructureTupleExpr::create(ASTContext &ctx,
srcExpr, dstExpr, ty);
}

FunctionConversionExpr::FunctionConversionExpr(Expr *subExpr, Type type)
: ImplicitConversionExpr(ExprKind::FunctionConversion, subExpr, type) {
while (auto *PE = dyn_cast<ParenExpr>(subExpr))
subExpr = PE->getSubExpr();
if (auto *CLE = dyn_cast<CaptureListExpr>(subExpr))
subExpr = CLE->getClosureBody();
if (auto *CE = dyn_cast<ClosureExpr>(subExpr))
CE->setConvertedTo(this);
}

SourceRange TupleExpr::getSourceRange() const {
auto start = LParenLoc;
if (start.isInvalid()) {
Expand Down
25 changes: 19 additions & 6 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
//
//===----------------------------------------------------------------------===//

#include "swift/AST/Expr.h"
#include "swift/AST/Type.h"
#define DEBUG_TYPE "libsil"

#include "swift/AST/AnyFunctionRef.h"
Expand Down Expand Up @@ -4321,12 +4323,10 @@ static CanSILFunctionType getUncachedSILFunctionTypeForConstant(
// The type of the native-to-foreign thunk for a swift closure.
if (constant.isForeign && constant.hasClosureExpr() &&
shouldStoreClangType(TC.getDeclRefRepresentation(constant))) {
auto clangType = TC.Context.getClangFunctionType(
origLoweredInterfaceType->getParams(),
origLoweredInterfaceType->getResult(),
FunctionTypeRepresentation::CFunctionPointer);
AbstractionPattern pattern =
AbstractionPattern(origLoweredInterfaceType, clangType);
assert(!extInfoBuilder.getClangTypeInfo().empty() &&
"clang type not found");
AbstractionPattern pattern = AbstractionPattern(
origLoweredInterfaceType, extInfoBuilder.getClangTypeInfo().getType());
return getSILFunctionTypeForAbstractCFunction(
TC, pattern, origLoweredInterfaceType, extInfoBuilder, constant);
}
Expand Down Expand Up @@ -4834,9 +4834,22 @@ getAbstractionPatternForConstant(ASTContext &ctx, SILDeclRef constant,
if (!constant.isForeign)
return AbstractionPattern(fnType);

if (const auto *closure = dyn_cast_or_null<ClosureExpr>(
constant.loc.dyn_cast<AbstractClosureExpr *>())) {
if (const auto *convertedTo = closure->getConvertedTo()) {
auto clangInfo = convertedTo->getType()
->castTo<AnyFunctionType>()
->getExtInfo()
.getClangTypeInfo();
if (!clangInfo.empty())
return AbstractionPattern(fnType, clangInfo.getType());
}
}

auto bridgedFn = getBridgedFunction(constant);
if (!bridgedFn)
return AbstractionPattern(fnType);

const clang::Decl *clangDecl = bridgedFn->getClangDecl();
if (!clangDecl)
return AbstractionPattern(fnType);
Expand Down
10 changes: 8 additions & 2 deletions lib/SILGen/SILGenBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,8 +1315,9 @@ static SILValue emitObjCUnconsumedArgument(SILGenFunction &SGF,
SILLocation loc,
SILValue arg) {
auto &lowering = SGF.getTypeLowering(arg->getType());
// If address-only, make a +1 copy and operate on that.
if (lowering.isAddressOnly() && SGF.useLoweredAddresses()) {
// If arg is non-trivial and has an address type, make a +1 copy and operate
// on that.
if (!lowering.isTrivial() && arg->getType().isAddress()) {
auto tmp = SGF.emitTemporaryAllocation(loc, arg->getType().getObjectType());
SGF.B.createCopyAddr(loc, arg, tmp, IsNotTake, IsInitialization);
return tmp;
Expand Down Expand Up @@ -1453,6 +1454,11 @@ emitObjCThunkArguments(SILGenFunction &SGF, SILLocation loc, SILDeclRef thunk,
auto buf = SGF.emitTemporaryAllocation(loc, native.getType());
native.forwardInto(SGF, loc, buf);
native = SGF.emitManagedBufferWithCleanup(buf);
} else if (!fnConv.isSILIndirect(nativeInputs[i]) &&
native.getType().isAddress()) {
// Load the value if the argument has an address type and the native
// function expects the argument to be passed directly.
native = SGF.emitManagedLoadCopy(loc, native.getValue());
}

if (nativeInputs[i].isConsumedInCaller()) {
Expand Down
30 changes: 21 additions & 9 deletions lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,19 @@ static ManagedValue convertCFunctionSignature(SILGenFunction &SGF,
FunctionConversionExpr *e,
SILType loweredResultTy,
llvm::function_ref<ManagedValue ()> fnEmitter) {
SILType loweredDestTy = SGF.getLoweredType(e->getType());
SILType loweredDestTy;
auto destTy = e->getType();
auto clangInfo =
destTy->castTo<AnyFunctionType>()->getExtInfo().getClangTypeInfo();
if (clangInfo.empty())
loweredDestTy = SGF.getLoweredType(destTy);
else
// This won't be necessary after we stop dropping clang types when
// canonicalizing function types.
loweredDestTy = SGF.getLoweredType(
AbstractionPattern(destTy->getCanonicalType(), clangInfo.getType()),
destTy);

ManagedValue result;

// We're converting between C function pointer types. They better be
Expand Down Expand Up @@ -1804,20 +1816,20 @@ ManagedValue emitCFunctionPointer(SILGenFunction &SGF,
#endif
semanticExpr = conv->getSubExpr()->getSemanticsProvidingExpr();
}

if (auto declRef = dyn_cast<DeclRefExpr>(semanticExpr)) {
setLocFromConcreteDeclRef(declRef->getDeclRef());
} else if (auto memberRef = dyn_cast<MemberRefExpr>(semanticExpr)) {
setLocFromConcreteDeclRef(memberRef->getMember());
} else if (isAnyClosureExpr(semanticExpr)) {
(void) emitAnyClosureExpr(SGF, semanticExpr,
[&](AbstractClosureExpr *closure) {
// Emit the closure body.
SGF.SGM.emitClosure(closure, SGF.getClosureTypeInfo(closure));
(void)emitAnyClosureExpr(
SGF, semanticExpr, [&](AbstractClosureExpr *closure) {
// Emit the closure body.
SGF.SGM.emitClosure(closure, SGF.getClosureTypeInfo(closure));

loc = closure;
return ManagedValue();
});
loc = closure;
return ManagedValue();
});
} else {
llvm_unreachable("c function pointer converted from a non-concrete decl ref");
}
Expand Down
3 changes: 0 additions & 3 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8495,9 +8495,6 @@ class SwiftToClangBasicReader :

llvm::Expected<const clang::Type *>
ModuleFile::getClangType(ClangTypeID TID) {
if (!getContext().LangOpts.UseClangFunctionTypes)
return nullptr;

if (TID == 0)
return nullptr;

Expand Down
53 changes: 47 additions & 6 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Decl.h"
#include "swift/AST/DiagnosticsCommon.h"
#include "swift/AST/DiagnosticsSema.h"
#include "swift/AST/Expr.h"
Expand All @@ -41,6 +42,7 @@
#include "swift/AST/SynthesizedFileUnit.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/TypeVisitor.h"
#include "swift/AST/Types.h"
#include "swift/Basic/Assertions.h"
#include "swift/Basic/Defer.h"
#include "swift/Basic/FileSystem.h"
Expand Down Expand Up @@ -5546,6 +5548,31 @@ static TypeAliasDecl *findTypeAliasForBuiltin(ASTContext &Ctx, Type T) {
return cast<TypeAliasDecl>(CurModuleResults[0]);
}

namespace {
struct ImplementationOnylWalker : TypeWalker {
bool hadImplementationOnlyDecl = false;
const ModuleDecl *currentModule;
ImplementationOnylWalker(const ModuleDecl *M) : currentModule(M) {}
Action walkToTypePre(Type ty) override {
if (auto *typeAlias = dyn_cast<TypeAliasType>(ty)) {
if (importedImplementationOnly(typeAlias->getDecl()))
return Action::Stop;
} else if (auto *nominal = ty->getAs<NominalType>()) {
if (importedImplementationOnly(nominal->getDecl()))
return Action::Stop;
}
return Action::Continue;
}
bool importedImplementationOnly(const Decl *D) {
if (currentModule->isImportedImplementationOnly(D->getModuleContext())) {
hadImplementationOnlyDecl = true;
return true;
}
return false;
}
};
} // namespace

class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
Serializer &S;

Expand Down Expand Up @@ -5899,10 +5926,19 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
using namespace decls_block;

auto resultType = S.addTypeRef(fnTy->getResult());
auto clangType =
S.getASTContext().LangOpts.UseClangFunctionTypes
? S.addClangTypeRef(fnTy->getClangTypeInfo().getType())
: ClangTypeID(0);
bool shouldSerializeClangType = true;
if (S.hadImplementationOnlyImport && S.M) {
// Deserializing clang types from implementation only modules could crash
// as the transitive clang module might not be available to retrieve the
// declarations from.
ImplementationOnylWalker walker{S.M};
Type(const_cast<FunctionType *>(fnTy)).walk(walker);
if (walker.hadImplementationOnlyDecl)
shouldSerializeClangType = false;
}
auto clangType = shouldSerializeClangType
? S.addClangTypeRef(fnTy->getClangTypeInfo().getType())
: ClangTypeID(0);

auto isolation = encodeIsolation(fnTy->getIsolation());

Expand Down Expand Up @@ -6984,8 +7020,13 @@ void Serializer::writeAST(ModuleOrSourceFile DC) {
nextFile->getTopLevelDeclsWithAuxiliaryDecls(fileDecls);

for (auto D : fileDecls) {
if (isa<ImportDecl>(D) || isa<MacroExpansionDecl>(D) ||
isa<TopLevelCodeDecl>(D) || isa<UsingDecl>(D)) {
if (const auto *ID = dyn_cast<ImportDecl>(D)) {
if (ID->getAttrs().hasAttribute<ImplementationOnlyAttr>())
hadImplementationOnlyImport = true;
continue;
}
if (isa<MacroExpansionDecl>(D) || isa<TopLevelCodeDecl>(D) ||
isa<UsingDecl>(D)) {
continue;
}

Expand Down
2 changes: 2 additions & 0 deletions lib/Serialization/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class Serializer : public SerializerBase {
/// an error in the AST.
bool hadError = false;

bool hadImplementationOnlyImport = false;

/// Helper for serializing entities in the AST block object graph.
///
/// Keeps track of assigning IDs to newly-seen entities, and collecting
Expand Down
13 changes: 13 additions & 0 deletions test/Interop/Cxx/class/Inputs/closure.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ struct NonTrivial {
int *p;
};

struct Trivial {
int i;
};

void cfunc(void (^ _Nonnull block)(NonTrivial)) noexcept {
block(NonTrivial());
}
Expand Down Expand Up @@ -75,4 +79,13 @@ inline void releaseSharedRef(SharedRef *_Nonnull x) {
}
}

void cfuncConstRefNonTrivial(void (*_Nonnull)(const NonTrivial &));
void cfuncConstRefTrivial(void (*_Nonnull)(const Trivial &));
void blockConstRefNonTrivial(void (^_Nonnull)(const NonTrivial &));
void blockConstRefTrivial(void (^_Nonnull)(const Trivial &));
#if __OBJC__
void cfuncConstRefStrong(void (*_Nonnull)(const ARCStrong &));
void blockConstRefStrong(void (^_Nonnull)(const ARCStrong &));
#endif

#endif // __CLOSURE__
Loading