Skip to content

Commit 1e7a1d9

Browse files
authored
Emit reabstraction thunks for implicit conversions between T.TangentType and Optional<T>.TangentType (#78076)
1 parent ac2603c commit 1e7a1d9

File tree

8 files changed

+250
-58
lines changed

8 files changed

+250
-58
lines changed

include/swift/AST/ASTContext.h

+10
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,16 @@ class ASTContext final {
11211121
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
11221122
llvm::SetVector<AutoDiffConfig> &results);
11231123

1124+
/// Given `Optional<T>.TangentVector` type, retrieve the
1125+
/// `Optional<T>.TangentVector.init` declaration.
1126+
ConstructorDecl *getOptionalTanInitDecl(CanType optionalTanType);
1127+
1128+
/// Optional<T>.TangentVector is a struct with a single
1129+
/// Optional<T.TangentVector> `value` property. This is an implementation
1130+
/// detail of OptionalDifferentiation.swift. Retrieve `VarDecl` corresponding
1131+
/// to this property.
1132+
VarDecl *getOptionalTanValueDecl(CanType optionalTanType);
1133+
11241134
/// Retrieve the next macro expansion discriminator within the given
11251135
/// name and context.
11261136
unsigned getNextMacroDiscriminator(MacroDiscriminatorContext context,

lib/AST/ASTContext.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,12 @@ struct ASTContext::Implementation {
343343
/// The declaration of Swift.Optional<T>.None.
344344
EnumElementDecl *OptionalNoneDecl = nullptr;
345345

346+
/// The declaration of Optional<T>.TangentVector.init
347+
ConstructorDecl *OptionalTanInitDecl = nullptr;
348+
349+
/// The declaration of Optional<T>.TangentVector.value
350+
VarDecl *OptionalTanValueDecl = nullptr;
351+
346352
/// The declaration of Swift.Void.
347353
TypeAliasDecl *VoidDecl = nullptr;
348354

@@ -2242,6 +2248,52 @@ void ASTContext::loadObjCMethods(
22422248
}
22432249
}
22442250

2251+
ConstructorDecl *ASTContext::getOptionalTanInitDecl(CanType optionalTanType) {
2252+
if (!getImpl().OptionalTanInitDecl) {
2253+
auto *optionalTanDecl = optionalTanType.getNominalOrBoundGenericNominal();
2254+
// Look up the `Optional<T>.TangentVector.init` declaration.
2255+
auto initLookup =
2256+
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2257+
ConstructorDecl *constructorDecl = nullptr;
2258+
for (auto *candidate : initLookup) {
2259+
auto candidateModule = candidate->getModuleContext();
2260+
if (candidateModule->getName() == Id_Differentiation ||
2261+
candidateModule->isStdlibModule()) {
2262+
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2263+
constructorDecl = cast<ConstructorDecl>(candidate);
2264+
#ifdef NDEBUG
2265+
break;
2266+
#endif
2267+
}
2268+
}
2269+
assert(constructorDecl && "No `Optional.TangentVector.init`");
2270+
2271+
getImpl().OptionalTanInitDecl = constructorDecl;
2272+
}
2273+
2274+
return getImpl().OptionalTanInitDecl;
2275+
}
2276+
2277+
VarDecl *ASTContext::getOptionalTanValueDecl(CanType optionalTanType) {
2278+
if (!getImpl().OptionalTanValueDecl) {
2279+
// TODO: Maybe it would be better to have getters / setters here that we
2280+
// can call and hide this implementation detail?
2281+
StructDecl *optStructDecl = optionalTanType.getStructOrBoundGenericStruct();
2282+
assert(optStructDecl && "Unexpected type of Optional.TangentVector");
2283+
2284+
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
2285+
assert(properties.size() == 1 && "Unexpected type of Optional.TangentVector");
2286+
VarDecl *wrappedValueVar = properties[0];
2287+
2288+
assert(wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() ==
2289+
getOptionalDecl() && "Unexpected type of Optional.TangentVector");
2290+
2291+
getImpl().OptionalTanValueDecl = wrappedValueVar;
2292+
}
2293+
2294+
return getImpl().OptionalTanValueDecl;
2295+
}
2296+
22452297
void ASTContext::loadDerivativeFunctionConfigurations(
22462298
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
22472299
llvm::SetVector<AutoDiffConfig> &results) {

lib/AST/ASTDumper.cpp

+20-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "swift/AST/ASTPrinter.h"
2020
#include "swift/AST/ASTVisitor.h"
2121
#include "swift/AST/Attr.h"
22+
#include "swift/AST/AutoDiff.h"
2223
#include "swift/AST/ClangModuleLoader.h"
2324
#include "swift/AST/ForeignAsyncConvention.h"
2425
#include "swift/AST/ForeignErrorConvention.h"
@@ -6159,7 +6160,7 @@ namespace {
61596160
}
61606161

61616162
void printAnyFunctionTypeCommonRec(AnyFunctionType *T, Label label,
6162-
StringRef name) {
6163+
StringRef name) {
61636164
printCommon(name, label);
61646165

61656166
if (T->hasExtInfo()) {
@@ -6174,6 +6175,24 @@ namespace {
61746175
printFlag(T->isAsync(), "async");
61756176
printFlag(T->isThrowing(), "throws");
61766177
printFlag(T->hasSendingResult(), "sending_result");
6178+
if (T->isDifferentiable()) {
6179+
switch (T->getDifferentiabilityKind()) {
6180+
default:
6181+
llvm_unreachable("unexpected differentiability kind");
6182+
case DifferentiabilityKind::Reverse:
6183+
printFlag("@differentiable(reverse)");
6184+
break;
6185+
case DifferentiabilityKind::Forward:
6186+
printFlag("@differentiable(_forward)");
6187+
break;
6188+
case DifferentiabilityKind::Linear:
6189+
printFlag("@differentiable(_linear)");
6190+
break;
6191+
case DifferentiabilityKind::Normal:
6192+
printFlag("@differentiable");
6193+
break;
6194+
}
6195+
}
61776196
}
61786197
if (Type globalActor = T->getGlobalActor()) {
61796198
printFieldQuoted(globalActor.getString(), Label::always("global_actor"));

lib/SILGen/SILGenFunction.h

+19
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,25 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
25212521
CanSILFunctionType toType,
25222522
bool reorderSelf);
25232523

2524+
/// Emit conversion from T.TangentVector to Optional<T>.TangentVector.
2525+
ManagedValue
2526+
emitTangentVectorToOptionalTangentVector(SILLocation loc,
2527+
ManagedValue input,
2528+
CanType wrappedType, // `T`
2529+
CanType inputType, // `T.TangentVector`
2530+
CanType outputType, // `Optional<T>.TangentVector`
2531+
SGFContext ctxt);
2532+
2533+
/// Emit conversion from Optional<T>.TangentVector to T.TangentVector.
2534+
ManagedValue
2535+
emitOptionalTangentVectorToTangentVector(SILLocation loc,
2536+
ManagedValue input,
2537+
CanType wrappedType, // `T`
2538+
CanType inputType, // `Optional<T>.TangentVector`
2539+
CanType outputType, // `T.TangentVector`
2540+
SGFContext ctxt);
2541+
2542+
25242543
//===--------------------------------------------------------------------===//
25252544
// Back Deployment thunks
25262545
//===--------------------------------------------------------------------===//

lib/SILGen/SILGenPoly.cpp

+110
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
//===----------------------------------------------------------------------===//
8484

8585
#define DEBUG_TYPE "silgen-poly"
86+
#include "ArgumentSource.h"
8687
#include "ExecutorBreadcrumb.h"
8788
#include "FunctionInputGenerator.h"
8889
#include "Initialization.h"
@@ -294,6 +295,67 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
294295
});
295296
}
296297

298+
// Convert T.TangentVector to Optional<T>.TangentVector.
299+
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
300+
// So we just need to call appropriate .init on it.
301+
ManagedValue SILGenFunction::emitTangentVectorToOptionalTangentVector(
302+
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
303+
CanType outputType, SGFContext ctxt) {
304+
// Look up the `Optional<T>.TangentVector.init` declaration.
305+
auto *constructorDecl = getASTContext().getOptionalTanInitDecl(outputType);
306+
307+
// `Optional<T.TangentVector>`
308+
CanType optionalOfWrappedTanType = inputType.wrapInOptionalType();
309+
310+
const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType);
311+
auto optVal = emitInjectOptional(
312+
loc, optTL, SGFContext(), [&](SGFContext objectCtxt) { return input; });
313+
314+
auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable);
315+
auto diffConf = lookupConformance(wrappedType, diffProto);
316+
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
317+
ConcreteDeclRef initDecl(
318+
constructorDecl,
319+
SubstitutionMap::get(constructorDecl->getGenericSignature(),
320+
{wrappedType}, {diffConf}));
321+
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
322+
args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType));
323+
324+
auto result = emitApplyAllocatingInitializer(loc, initDecl, std::move(args),
325+
Type(), ctxt);
326+
return std::move(result).getScalarValue();
327+
}
328+
329+
ManagedValue SILGenFunction::emitOptionalTangentVectorToTangentVector(
330+
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
331+
CanType outputType, SGFContext ctxt) {
332+
// Optional<T>.TangentVector should be a struct with a single
333+
// Optional<T.TangentVector> `value` property. This is an implementation
334+
// detail of OptionalDifferentiation.swift
335+
// TODO: Maybe it would be better to have explicit getters / setters here that we can
336+
// call and hide this implementation detail?
337+
VarDecl *wrappedValueVar = getASTContext().getOptionalTanValueDecl(inputType);
338+
// `Optional<T.TangentVector>`
339+
CanType optionalOfWrappedTanType = outputType.wrapInOptionalType();
340+
341+
FormalEvaluationScope scope(*this);
342+
343+
auto sig = wrappedValueVar->getDeclContext()->getGenericSignatureOfContext();
344+
auto *diffProto =
345+
getASTContext().getProtocol(KnownProtocolKind::Differentiable);
346+
auto diffConf = lookupConformance(wrappedType, diffProto);
347+
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
348+
349+
auto wrappedVal = emitRValueForStorageLoad(
350+
loc, input, inputType, /*super*/ false, wrappedValueVar,
351+
PreparedArguments(), SubstitutionMap::get(sig, {wrappedType}, {diffConf}),
352+
AccessSemantics::Ordinary, optionalOfWrappedTanType, SGFContext());
353+
354+
return emitCheckedGetOptionalValueFrom(
355+
loc, std::move(wrappedVal).getScalarValue(),
356+
/*isImplicitUnwrap*/ true, getTypeLowering(optionalOfWrappedTanType), ctxt);
357+
}
358+
297359
/// Apply this transformation to an arbitrary value.
298360
RValue Transform::transform(RValue &&input,
299361
AbstractionPattern inputOrigType,
@@ -675,6 +737,54 @@ ManagedValue Transform::transform(ManagedValue v,
675737
return std::move(result).getAsSingleValue(SGF, Loc);
676738
}
677739

740+
// - T.TangentVector to Optional<T>.TangentVector
741+
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
742+
// So we just need to call appropriate .init on it.
743+
// However, we might have T.TangentVector == T, so we need to calculate all
744+
// required types first.
745+
{
746+
CanType optionalTy = isa<NominalType>(outputSubstType)
747+
? outputSubstType.getNominalParent()
748+
: CanType(); // `Optional<T>`
749+
if (optionalTy && (bool)optionalTy.getOptionalObjectType()) {
750+
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
751+
// Check that T.TangentVector is indeed inputSubstType (this also handles
752+
// case when T == T.TangentVector).
753+
// Also check that outputSubstType is an Optional<T>.TangentVector.
754+
auto inputTanSpace =
755+
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
756+
auto outputTanSpace =
757+
optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule());
758+
if (inputTanSpace && outputTanSpace &&
759+
inputTanSpace->getCanonicalType() == inputSubstType &&
760+
outputTanSpace->getCanonicalType() == outputSubstType)
761+
return SGF.emitTangentVectorToOptionalTangentVector(
762+
Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt);
763+
}
764+
}
765+
766+
// - Optional<T>.TangentVector to T.TangentVector.
767+
{
768+
CanType optionalTy = isa<NominalType>(inputSubstType)
769+
? inputSubstType.getNominalParent()
770+
: CanType(); // `Optional<T>`
771+
if (optionalTy && (bool)optionalTy.getOptionalObjectType()) {
772+
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
773+
// Check that T.TangentVector is indeed outputSubstType (this also handles
774+
// case when T == T.TangentVector)
775+
// Also check that inputSubstType is an Optional<T>.TangentVector
776+
auto inputTanSpace =
777+
optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule());
778+
auto outputTanSpace =
779+
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
780+
if (inputTanSpace && outputTanSpace &&
781+
inputTanSpace->getCanonicalType() == inputSubstType &&
782+
outputTanSpace->getCanonicalType() == outputSubstType)
783+
return SGF.emitOptionalTangentVectorToTangentVector(
784+
Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt);
785+
}
786+
}
787+
678788
// Should have handled the conversion in one of the cases above.
679789
v.dump();
680790
llvm_unreachable("Unhandled transform?");

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

+6-55
Original file line numberDiff line numberDiff line change
@@ -1852,27 +1852,9 @@ class PullbackCloner::Implementation final
18521852

18531853
auto adjOpt = getAdjointValue(bb, ei);
18541854
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
1855-
StructDecl *adjStructDecl =
1856-
adjStruct->getType().getStructOrBoundGenericStruct();
1857-
1858-
VarDecl *adjOptVar = nullptr;
1859-
if (adjStructDecl) {
1860-
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1861-
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1862-
}
1863-
1864-
EnumDecl *adjOptDecl =
1865-
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1866-
: nullptr;
1867-
1868-
// Optional<T>.TangentVector should be a struct with a single
1869-
// Optional<T.TangentVector> property. This is an implementation detail of
1870-
// OptionalDifferentiation.swift
1871-
// TODO: Maybe it would be better to have getters / setters here that we
1872-
// can call and hide this implementation detail?
1873-
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1874-
llvm_unreachable("Unexpected type of Optional.TangentVector");
18751855

1856+
VarDecl *adjOptVar =
1857+
getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType());
18761858
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);
18771859

18781860
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
@@ -1931,24 +1913,8 @@ class PullbackCloner::Implementation final
19311913
}
19321914

19331915
SILValue adjDest = getAdjointBuffer(bb, origEnum);
1934-
StructDecl *adjStructDecl =
1935-
adjDest->getType().getStructOrBoundGenericStruct();
1936-
1937-
VarDecl *adjOptVar = nullptr;
1938-
if (adjStructDecl) {
1939-
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1940-
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1941-
}
1942-
1943-
EnumDecl *adjOptDecl =
1944-
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1945-
: nullptr;
1946-
1947-
// Optional<T>.TangentVector should be a struct with a single
1948-
// Optional<T.TangentVector> property. This is an implementation detail of
1949-
// OptionalDifferentiation.swift
1950-
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1951-
llvm_unreachable("Unexpected type of Optional.TangentVector");
1916+
VarDecl *adjOptVar =
1917+
getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType());
19521918

19531919
SILLocation loc = origData->getLoc();
19541920
StructElementAddrInst *adjOpt =
@@ -2678,24 +2644,9 @@ AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
26782644
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
26792645
// `Optional<T>.TangentVector`
26802646
auto optionalTanTy = getRemappedTangentType(optionalTy);
2681-
auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal();
26822647
// Look up the `Optional<T>.TangentVector.init` declaration.
2683-
auto initLookup =
2684-
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2685-
ConstructorDecl *constructorDecl = nullptr;
2686-
for (auto *candidate : initLookup) {
2687-
auto candidateModule = candidate->getModuleContext();
2688-
if (candidateModule->getName() ==
2689-
builder.getASTContext().Id_Differentiation ||
2690-
candidateModule->isStdlibModule()) {
2691-
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2692-
constructorDecl = cast<ConstructorDecl>(candidate);
2693-
#ifdef NDEBUG
2694-
break;
2695-
#endif
2696-
}
2697-
}
2698-
assert(constructorDecl && "No `Optional.TangentVector.init`");
2648+
ConstructorDecl *constructorDecl =
2649+
getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType());
26992650

27002651
// Allocate a local buffer for the `Optional` adjoint value.
27012652
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);

lib/Sema/CSApply.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "clang/Sema/TemplateDeduction.h"
5353
#include "llvm/ADT/APFloat.h"
5454
#include "llvm/ADT/APInt.h"
55+
#include "llvm/ADT/STLExtras.h"
5556
#include "llvm/ADT/SmallString.h"
5657
#include "llvm/Support/Compiler.h"
5758
#include "llvm/Support/SaveAndRestore.h"
@@ -7538,8 +7539,20 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
75387539
fromEI.intoBuilder()
75397540
.withDifferentiabilityKind(toEI.getDifferentiabilityKind())
75407541
.build();
7541-
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult(),
7542-
newEI);
7542+
SmallVector<AnyFunctionType::Param, 4> params(fromFunc->getParams());
7543+
assert(params.size() == toFunc->getParams().size() &&
7544+
"unexpected @differentiable conversion");
7545+
// Propagate @noDerivate from target function type
7546+
for (auto paramAndIndex : llvm::enumerate(toFunc->getParams())) {
7547+
if (!paramAndIndex.value().isNoDerivative())
7548+
continue;
7549+
7550+
auto &param = params[paramAndIndex.index()];
7551+
param =
7552+
param.withFlags(param.getParameterFlags().withNoDerivative(true));
7553+
}
7554+
7555+
fromFunc = FunctionType::get(params, fromFunc->getResult(), newEI);
75437556
switch (toEI.getDifferentiabilityKind()) {
75447557
// TODO: Ban `Normal` and `Forward` cases.
75457558
case DifferentiabilityKind::Normal:

0 commit comments

Comments
 (0)