Skip to content

Commit 4f5c75a

Browse files
hlopkogribozavrscentini
authored
[cxx-interop] Instantiate C++ class templates from Swift (#33284)
This PR makes it possible to instantiate C++ class templates from Swift. Given a C++ header: ```c++ // C++ module `ClassTemplates` template<class T> struct MagicWrapper { T t; }; struct MagicNumber {}; ``` it is now possible to write in Swift: ```swift import ClassTemplates func x() -> MagicWrapper<MagicNumber> { return MagicWrapper<MagicNumber>() } ``` This is achieved by importing C++ class templates as generic structs, and then when Swift type checker calls `applyGenericArguments` we detect when the generic struct is backed by the C++ class template and call Clang to instantiate the template. In order to make it possible to put class instantiations such as `MagicWrapper<MagicNumber>` into Swift signatures, we have created a new field in `StructDecl` named `TemplateInstantiationType` where the typechecker stores the `BoundGenericType` which we serialize. Deserializer then notices that the `BoundGenericType` is actually a C++ class template and performs the instantiation logic. Depends on #33420. Progress towards https://bugs.swift.org/browse/SR-13261. Fixes https://bugs.swift.org/browse/SR-13775. Co-authored-by: Dmitri Gribenko <[email protected]> Co-authored-by: Rosica Dejanovska <[email protected]>
1 parent c4cfcc7 commit 4f5c75a

35 files changed

+520
-40
lines changed

include/swift/AST/ClangModuleLoader.h

+10
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,16 @@ class ClangModuleLoader : public ModuleLoader {
177177
StringRef relatedEntityKind,
178178
llvm::function_ref<void(TypeDecl *)> receiver) = 0;
179179

180+
/// Instantiate and import class template using given arguments.
181+
///
182+
/// This method will find the clang::ClassTemplateSpecialization decl if
183+
/// it already exists, or it will create one. Then it will import this
184+
/// decl the same way as we import typedeffed class templates - using
185+
/// the hidden struct prefixed with `__CxxTemplateInst`.
186+
virtual StructDecl *
187+
instantiateCXXClassTemplate(clang::ClassTemplateDecl *decl,
188+
ArrayRef<clang::TemplateArgument> arguments) = 0;
189+
180190
/// Try to parse the string as a Clang function type.
181191
///
182192
/// Returns null if there was a parsing failure.

include/swift/AST/Decl.h

+39
Original file line numberDiff line numberDiff line change
@@ -3402,6 +3402,42 @@ class EnumDecl final : public NominalTypeDecl {
34023402
class StructDecl final : public NominalTypeDecl {
34033403
SourceLoc StructLoc;
34043404

3405+
// We import C++ class templates as generic structs. Then when in Swift code
3406+
// we want to substitude generic parameters with actual arguments, we
3407+
// convert the arguments to C++ equivalents and ask Clang to instantiate the
3408+
// C++ template. Then we import the C++ class template instantiation
3409+
// as a non-generic structs with a name prefixed with `__CxxTemplateInst`.
3410+
//
3411+
// To reiterate:
3412+
// 1) We need to have a C++ class template declaration in the Clang AST. This
3413+
// declaration is simply imported from a Clang module.
3414+
// 2) We need a Swift generic struct in the Swift AST. This will provide
3415+
// template arguments to Clang.
3416+
// 3) We produce a C++ class template instantiation in the Clang AST
3417+
// using 1) and 2). This declaration does not exist in the Clang module
3418+
// AST initially in the general case, it's added there on instantiation.
3419+
// 4) We import the instantiation as a Swift struct, with the name prefixed
3420+
// with `__CxxTemplateInst`.
3421+
//
3422+
// This causes a problem for serialization/deserialization of the Swift
3423+
// module. Imagine the Swift struct from 4) is used in the function return
3424+
// type. We cannot just serialize the non generic Swift struct, because on
3425+
// deserialization we would need to find its backing Clang declaration
3426+
// (the C++ class template instantiation), and it won't be found in the
3427+
// general case. Only the C++ class template from step 1) is in the Clang
3428+
// AST.
3429+
//
3430+
// What we need is to serialize enough information to be
3431+
// able to instantiate C++ class template on deserialization. It turns out
3432+
// that all that information is conveniently covered by the BoundGenericType,
3433+
// which we store in this field. The field is set during the typechecking at
3434+
// the time when we instantiate the C++ class template.
3435+
//
3436+
// Alternative, and likely better solution long term, is to serialize the
3437+
// C++ class template instantiation into a synthetic Clang module, and load
3438+
// this Clang module on deserialization.
3439+
Type TemplateInstantiationType = Type();
3440+
34053441
public:
34063442
StructDecl(SourceLoc StructLoc, Identifier Name, SourceLoc NameLoc,
34073443
ArrayRef<TypeLoc> Inherited,
@@ -3445,6 +3481,9 @@ class StructDecl final : public NominalTypeDecl {
34453481
bool isCxxNonTrivial() const { return Bits.StructDecl.IsCxxNonTrivial; }
34463482

34473483
void setIsCxxNonTrivial(bool v) { Bits.StructDecl.IsCxxNonTrivial = v; }
3484+
3485+
Type getTemplateInstantiationType() const { return TemplateInstantiationType; }
3486+
void setTemplateInstantiationType(Type t) { TemplateInstantiationType = t; }
34483487
};
34493488

34503489
/// This is the base type for AncestryOptions. Each flag describes possible

include/swift/AST/DiagnosticsSema.def

+4
Original file line numberDiff line numberDiff line change
@@ -2788,6 +2788,10 @@ ERROR(unexportable_clang_function_type,none,
27882788
"it may use anonymous types or types defined outside of a module",
27892789
(Type))
27902790

2791+
ERROR(cxx_class_instantiation_failed,none,
2792+
"couldn't instantiate a C++ class template",
2793+
())
2794+
27912795
WARNING(warn_implementation_only_conflict,none,
27922796
"%0 inconsistently imported as implementation-only",
27932797
(Identifier))

include/swift/ClangImporter/ClangImporter.h

+4
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ class ClangImporter final : public ClangModuleLoader {
240240
StringRef relatedEntityKind,
241241
llvm::function_ref<void(TypeDecl *)> receiver) override;
242242

243+
StructDecl *
244+
instantiateCXXClassTemplate(clang::ClassTemplateDecl *decl,
245+
ArrayRef<clang::TemplateArgument> arguments) override;
246+
243247
/// Just like Decl::getClangNode() except we look through to the 'Code'
244248
/// enum of an error wrapper struct.
245249
ClangNode getEffectiveClangNode(const Decl *decl) const;

lib/ClangImporter/ClangImporter.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -4150,3 +4150,25 @@ clang::FunctionDecl *ClangImporter::instantiateCXXFunctionTemplate(
41504150
sema.InstantiateFunctionDefinition(clang::SourceLocation(), spec);
41514151
return spec;
41524152
}
4153+
4154+
StructDecl *
4155+
ClangImporter::instantiateCXXClassTemplate(
4156+
clang::ClassTemplateDecl *decl,
4157+
ArrayRef<clang::TemplateArgument> arguments) {
4158+
void *InsertPos = nullptr;
4159+
auto *ctsd = decl->findSpecialization(arguments, InsertPos);
4160+
if (!ctsd) {
4161+
ctsd = clang::ClassTemplateSpecializationDecl::Create(
4162+
decl->getASTContext(), decl->getTemplatedDecl()->getTagKind(),
4163+
decl->getDeclContext(), decl->getTemplatedDecl()->getBeginLoc(),
4164+
decl->getLocation(), decl, arguments, nullptr);
4165+
decl->AddSpecialization(ctsd, InsertPos);
4166+
}
4167+
4168+
auto CanonType = decl->getASTContext().getTypeDeclType(ctsd);
4169+
assert(isa<clang::RecordType>(CanonType) &&
4170+
"type of non-dependent specialization is not a RecordType");
4171+
4172+
return dyn_cast_or_null<StructDecl>(
4173+
Impl.importDecl(ctsd, Impl.CurrentVersion));
4174+
}

lib/ClangImporter/ImportDecl.cpp

+37-20
Original file line numberDiff line numberDiff line change
@@ -3575,13 +3575,6 @@ namespace {
35753575
decl->getDefinition());
35763576
assert(def && "Class template instantiation didn't have definition");
35773577

3578-
// If this type is fully specialized (i.e. "Foo<>" or "Foo<int, int>"),
3579-
// bail to prevent a crash.
3580-
// TODO: we should be able to support fully specialized class templates.
3581-
// See SR-13775 for more info.
3582-
if (def->getTypeAsWritten())
3583-
return nullptr;
3584-
35853578
// FIXME: This will instantiate all members of the specialization (and detect
35863579
// instantiation failures in them), which can be more than is necessary
35873580
// and is more than what Clang does. As a result we reject some C++
@@ -3592,19 +3585,6 @@ namespace {
35923585
return VisitCXXRecordDecl(def);
35933586
}
35943587

3595-
Decl *VisitClassTemplateDecl(const clang::ClassTemplateDecl *decl) {
3596-
// When loading a namespace's sub-decls, we won't add template
3597-
// specilizations, so make sure to do that here.
3598-
for (auto spec : decl->specializations()) {
3599-
if (auto importedSpec = Impl.importDecl(spec, getVersion())) {
3600-
if (auto namespaceDecl =
3601-
dyn_cast<EnumDecl>(importedSpec->getDeclContext()))
3602-
namespaceDecl->addMember(importedSpec);
3603-
}
3604-
}
3605-
return nullptr;
3606-
}
3607-
36083588
Decl *VisitClassTemplatePartialSpecializationDecl(
36093589
const clang::ClassTemplatePartialSpecializationDecl *decl) {
36103590
// Note: partial template specializations are not imported.
@@ -4256,6 +4236,43 @@ namespace {
42564236
correctSwiftName, None, decl);
42574237
}
42584238

4239+
Decl *VisitClassTemplateDecl(const clang::ClassTemplateDecl *decl) {
4240+
// When loading a namespace's sub-decls, we won't add template
4241+
// specilizations, so make sure to do that here.
4242+
for (auto spec : decl->specializations()) {
4243+
if (auto importedSpec = Impl.importDecl(spec, getVersion())) {
4244+
if (auto namespaceDecl =
4245+
dyn_cast<EnumDecl>(importedSpec->getDeclContext()))
4246+
namespaceDecl->addMember(importedSpec);
4247+
}
4248+
}
4249+
4250+
Optional<ImportedName> correctSwiftName;
4251+
auto importedName = importFullName(decl, correctSwiftName);
4252+
auto name = importedName.getDeclName().getBaseIdentifier();
4253+
if (name.empty())
4254+
return nullptr;
4255+
auto loc = Impl.importSourceLoc(decl->getLocation());
4256+
auto dc = Impl.importDeclContextOf(
4257+
decl, importedName.getEffectiveContext());
4258+
4259+
SmallVector<GenericTypeParamDecl *, 4> genericParams;
4260+
for (auto &param : *decl->getTemplateParameters()) {
4261+
auto genericParamDecl = Impl.createDeclWithClangNode<GenericTypeParamDecl>(
4262+
param, AccessLevel::Public, dc,
4263+
Impl.SwiftContext.getIdentifier(param->getName()),
4264+
Impl.importSourceLoc(param->getLocation()),
4265+
/*depth*/ 0, /*index*/ genericParams.size());
4266+
genericParams.push_back(genericParamDecl);
4267+
}
4268+
auto genericParamList = GenericParamList::create(
4269+
Impl.SwiftContext, loc, genericParams, loc);
4270+
4271+
auto structDecl = Impl.createDeclWithClangNode<StructDecl>(
4272+
decl, AccessLevel::Public, loc, name, loc, None, genericParamList, dc);
4273+
return structDecl;
4274+
}
4275+
42594276
Decl *VisitUsingDecl(const clang::UsingDecl *decl) {
42604277
// Using declarations are not imported.
42614278
return nullptr;

lib/Sema/TypeCheckType.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
#include "swift/ClangImporter/ClangImporter.h"
4545
#include "swift/Strings.h"
4646
#include "swift/Subsystems.h"
47+
#include "clang/AST/ASTContext.h"
48+
#include "clang/AST/DeclBase.h"
49+
#include "clang/AST/DeclTemplate.h"
4750
#include "llvm/ADT/APInt.h"
4851
#include "llvm/ADT/SmallPtrSet.h"
4952
#include "llvm/ADT/SmallString.h"
@@ -834,6 +837,46 @@ static Type applyGenericArguments(Type type, TypeResolution resolution,
834837
diags.diagnose(loc, diag::use_of_void_pointer, "").
835838
fixItReplace(generic->getSourceRange(), "UnsafeRawPointer");
836839
}
840+
841+
if (auto clangDecl = decl->getClangDecl()) {
842+
if (auto classTemplateDecl =
843+
dyn_cast<clang::ClassTemplateDecl>(clangDecl)) {
844+
SmallVector<Type, 2> typesOfGenericArgs;
845+
for (auto typeRepr : generic->getGenericArgs()) {
846+
typesOfGenericArgs.push_back(resolution.resolveType(typeRepr));
847+
}
848+
849+
SmallVector<clang::TemplateArgument, 2> templateArguments;
850+
std::unique_ptr<TemplateInstantiationError> error =
851+
ctx.getClangTemplateArguments(
852+
classTemplateDecl->getTemplateParameters(), typesOfGenericArgs,
853+
templateArguments);
854+
855+
if (error) {
856+
std::string failedTypesStr;
857+
llvm::raw_string_ostream failedTypesStrStream(failedTypesStr);
858+
llvm::interleaveComma(error->failedTypes, failedTypesStrStream);
859+
// TODO: This error message should not reference implementation details.
860+
// See: https://github.com/apple/swift/pull/33053#discussion_r477003350
861+
ctx.Diags.diagnose(
862+
loc, diag::unable_to_convert_generic_swift_types.ID,
863+
{classTemplateDecl->getName(), StringRef(failedTypesStr)});
864+
return ErrorType::get(ctx);
865+
}
866+
867+
auto *clangModuleLoader = decl->getASTContext().getClangModuleLoader();
868+
auto instantiatedDecl = clangModuleLoader->instantiateCXXClassTemplate(
869+
const_cast<clang::ClassTemplateDecl *>(classTemplateDecl),
870+
templateArguments);
871+
if (instantiatedDecl) {
872+
instantiatedDecl->setTemplateInstantiationType(result);
873+
return instantiatedDecl->getDeclaredInterfaceType();
874+
} else {
875+
diags.diagnose(loc, diag::cxx_class_instantiation_failed);
876+
return ErrorType::get(ctx);
877+
}
878+
}
879+
}
837880
return result;
838881
}
839882

lib/Serialization/Deserialization.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "swift/Serialization/SerializedModuleLoader.h"
3535
#include "swift/Basic/Defer.h"
3636
#include "swift/Basic/Statistic.h"
37+
#include "clang/AST/DeclTemplate.h"
3738
#include "llvm/ADT/Statistic.h"
3839
#include "llvm/Support/Compiler.h"
3940
#include "llvm/Support/Debug.h"
@@ -5353,6 +5354,31 @@ class TypeDeserializer {
53535354
genericArgs.push_back(argTy.get());
53545355
}
53555356

5357+
if (auto clangDecl = nominal->getClangDecl()) {
5358+
if (auto ctd = dyn_cast<clang::ClassTemplateDecl>(clangDecl)) {
5359+
auto clangImporter = static_cast<ClangImporter *>(
5360+
nominal->getASTContext().getClangModuleLoader());
5361+
5362+
SmallVector<Type, 2> typesOfGenericArgs;
5363+
for (auto arg : genericArgs) {
5364+
typesOfGenericArgs.push_back(arg);
5365+
}
5366+
5367+
SmallVector<clang::TemplateArgument, 2> templateArguments;
5368+
std::unique_ptr<TemplateInstantiationError> error =
5369+
ctx.getClangTemplateArguments(ctd->getTemplateParameters(),
5370+
typesOfGenericArgs,
5371+
templateArguments);
5372+
5373+
auto instantiation = clangImporter->instantiateCXXClassTemplate(
5374+
const_cast<clang::ClassTemplateDecl *>(ctd), templateArguments);
5375+
5376+
instantiation->setTemplateInstantiationType(
5377+
BoundGenericType::get(nominal, parentTy, genericArgs));
5378+
return instantiation->getDeclaredInterfaceType();
5379+
}
5380+
}
5381+
53565382
return BoundGenericType::get(nominal, parentTy, genericArgs);
53575383
}
53585384

lib/Serialization/Serialization.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "swift/Demangling/ManglingMacros.h"
4848
#include "swift/Serialization/SerializationOptions.h"
4949
#include "swift/Strings.h"
50+
#include "clang/AST/DeclTemplate.h"
5051
#include "llvm/ADT/SmallSet.h"
5152
#include "llvm/ADT/SmallString.h"
5253
#include "llvm/ADT/StringExtras.h"
@@ -73,6 +74,7 @@ using namespace llvm::support;
7374
using swift::version::Version;
7475
using llvm::BCBlockRAII;
7576

77+
7678
ASTContext &SerializerBase::getASTContext() {
7779
return M->getASTContext();
7880
}
@@ -626,13 +628,25 @@ DeclID Serializer::addDeclRef(const Decl *D, bool allowTypeAliasXRef) {
626628
}
627629

628630
serialization::TypeID Serializer::addTypeRef(Type ty) {
631+
Type typeToSerialize = ty;
632+
if (ty) {
633+
if (auto nominalDecl = ty->getAnyNominal()) {
634+
if (auto structDecl = dyn_cast<StructDecl>(nominalDecl)) {
635+
if (auto templateInstantiationType =
636+
structDecl->getTemplateInstantiationType()) {
637+
typeToSerialize = templateInstantiationType;
638+
}
639+
}
640+
}
641+
}
642+
629643
#ifndef NDEBUG
630-
PrettyStackTraceType trace(M->getASTContext(), "serializing", ty);
644+
PrettyStackTraceType trace(M->getASTContext(), "serializing", typeToSerialize);
631645
assert(M->getASTContext().LangOpts.AllowModuleWithCompilerErrors ||
632-
!ty || !ty->hasError() && "serializing type with an error");
646+
!typeToSerialize || !typeToSerialize->hasError() && "serializing type with an error");
633647
#endif
634648

635-
return TypesToSerialize.addRef(ty);
649+
return TypesToSerialize.addRef(typeToSerialize);
636650
}
637651

638652
serialization::ClangTypeID Serializer::addClangTypeRef(const clang::Type *ty) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import ClassTemplateForSwiftModule
2+
3+
public func makeWrappedMagicNumber() -> MagicWrapper<IntWrapper> {
4+
let t = IntWrapper(value: 42)
5+
return MagicWrapper<IntWrapper>(t: t)
6+
}
7+
8+
public func readWrappedMagicNumber(_ i: inout MagicWrapper<IntWrapper>) -> CInt {
9+
return i.getValuePlusArg(13)
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import ClassTemplateInNamespaceForSwiftModule
2+
3+
public func receiveShip(_ i: inout Space.Ship<Bool>) {}
4+
5+
public func returnShip() -> Space.Ship<Bool> {
6+
return Space.Ship<Bool>()
7+
}
8+
9+
public func receiveShipWithEngine(_ i: inout Space.Ship<Engine.Turbojet>) {}
10+
11+
public func returnShipWithEngine() -> Space.Ship<Engine.Turbojet> {
12+
return Space.Ship<Engine.Turbojet>()
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import ClassTemplateNestedTypeForSwiftModule
2+
3+
public func receiveShipEngine(_ i: inout Ship<Bool>.Engine) {}
4+
5+
public func returnShipEngine() -> Ship<Bool>.Engine {
6+
return Ship<Bool>.Engine()
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_FOR_SWIFT_MODULE_H
2+
#define TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_FOR_SWIFT_MODULE_H
3+
4+
struct IntWrapper {
5+
int value;
6+
int getValue() const { return value; }
7+
};
8+
9+
template<class T>
10+
struct MagicWrapper {
11+
T t;
12+
int getValuePlusArg(int arg) const { return t.getValue() + arg; }
13+
};
14+
15+
#endif // TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_FOR_SWIFT_MODULE_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_IN_NAMESPACE_FOR_SWIFT_MODULE_H
2+
#define TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_IN_NAMESPACE_FOR_SWIFT_MODULE_H
3+
4+
namespace Space {
5+
template <class T> struct Ship { T t; };
6+
} // namespace Space
7+
8+
namespace Engine {
9+
struct Turbojet {};
10+
} // namespace Engine
11+
12+
#endif // TEST_INTEROP_CXX_TEMPLATES_INPUTS_CLASS_TEMPLATE_IN_NAMESPACE_FOR_SWIFT_MODULE_H

0 commit comments

Comments
 (0)