Skip to content

Commit 5a68861

Browse files
committed
Emit reabstraction thunks for implicit conversions between T.TangentType and Optional<T>.TangentType
Fixes #77871
1 parent 23c577d commit 5a68861

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

lib/SILGen/SILGenPoly.cpp

+98
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
//===----------------------------------------------------------------------===//
8484

8585
#define DEBUG_TYPE "silgen-poly"
86+
#include "ArgumentSource.h"
8687
#include "ExecutorBreadcrumb.h"
8788
#include "FunctionInputGenerator.h"
8889
#include "Initialization.h"
@@ -675,6 +676,103 @@ ManagedValue Transform::transform(ManagedValue v,
675676
return std::move(result).getAsSingleValue(SGF, Loc);
676677
}
677678

679+
// - T.TangentVector to Optional<T>.TangentVector
680+
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
681+
// So we just need to call appropriate .init on it.
682+
// However, we might have T.TangentVector == T, so we need to calculate all
683+
// required types first.
684+
if (CanType optionalTy = outputSubstType.getNominalParent(); // `Optional<T>`
685+
optionalTy && (bool)optionalTy.getOptionalObjectType()) {
686+
// `T`
687+
CanType wrappedType = optionalTy.getOptionalObjectType();
688+
// Check that T.TangentVector is indeed inputSubstType (this also handles
689+
// case when T == T.TangentVector)
690+
auto inputTanSpace =
691+
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
692+
if (inputTanSpace && inputTanSpace->getCanonicalType() == inputSubstType) {
693+
auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal();
694+
// Look up the `Optional<T>.TangentVector.init` declaration.
695+
auto initLookup =
696+
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
697+
ConstructorDecl *constructorDecl = nullptr;
698+
for (auto *candidate : initLookup) {
699+
auto candidateModule = candidate->getModuleContext();
700+
if (candidateModule->getName() ==
701+
SGF. getASTContext().Id_Differentiation ||
702+
candidateModule->isStdlibModule()) {
703+
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
704+
constructorDecl = cast<ConstructorDecl>(candidate);
705+
#ifdef NDEBUG
706+
break;
707+
#endif
708+
}
709+
}
710+
assert(constructorDecl && "No `Optional.TangentVector.init`");
711+
712+
// `T.TangentVector`
713+
CanType wrappedTanType = inputTanSpace->getCanonicalType();
714+
// `Optional<T.TangentVector>`
715+
CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType();
716+
717+
const TypeLowering &optTL = SGF.getTypeLowering(optionalOfWrappedTanType);
718+
auto optVal = SGF.emitInjectOptional(Loc, optTL, ctxt,
719+
[&](SGFContext objectCtxt) {
720+
return v;
721+
});
722+
auto *diffProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
723+
auto diffConf = lookupConformance(wrappedType, diffProto);
724+
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
725+
ConcreteDeclRef initDecl(constructorDecl,
726+
SubstitutionMap::get(constructorDecl->getGenericSignature(),
727+
{wrappedType}, {diffConf}));
728+
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
729+
args.add(Loc, RValue(SGF, {optVal}, optionalOfWrappedTanType));
730+
731+
auto result = SGF.emitApplyAllocatingInitializer(Loc, initDecl,
732+
std::move(args), outputSubstType, ctxt);
733+
if (result.isInContext())
734+
return ManagedValue::forInContext();
735+
return std::move(result).getAsSingleValue(SGF, Loc);
736+
}
737+
}
738+
739+
// - Optional<T>.TangentVector to T.TangentVector.
740+
if (CanType optionalTy = inputSubstType.getNominalParent(); // `Optional<T>`
741+
optionalTy && (bool)optionalTy.getOptionalObjectType()) {
742+
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
743+
// Check that T.TangentVector is indeed outputSubstType (this also handles
744+
// case when T == T.TangentVector)
745+
auto outputTanSpace =
746+
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
747+
if (outputTanSpace && outputTanSpace->getCanonicalType() == outputSubstType) {
748+
// Optional<T>.TangentVector should be a struct with a single
749+
// Optional<T.TangentVector> property. This is an implementation detail of
750+
// OptionalDifferentiation.swift
751+
// TODO: Maybe it would be better to have getters / setters here that we
752+
// can call and hide this implementation detail?
753+
StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct();
754+
VarDecl *wrappedValueVar = nullptr;
755+
if (optStructDecl) {
756+
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
757+
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
758+
}
759+
760+
EnumDecl *optDecl = wrappedValueVar ?
761+
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
762+
nullptr;
763+
764+
if (!optStructDecl || optDecl != SGF.getASTContext().getOptionalDecl())
765+
llvm_unreachable("Unexpected type of Optional.TangentVector");
766+
767+
FormalEvaluationScope scope(SGF);
768+
auto wrappedVal = SGF.B.createStructExtract(Loc, v, wrappedValueVar);
769+
return SGF.emitCheckedGetOptionalValueFrom(Loc, wrappedVal,
770+
/*isImplicitUnwrap*/ true,
771+
SGF.getTypeLowering(wrappedVal.getType()),
772+
ctxt);
773+
}
774+
}
775+
678776
// Should have handled the conversion in one of the cases above.
679777
v.dump();
680778
llvm_unreachable("Unhandled transform?");

lib/Sema/CSApply.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "clang/Sema/TemplateDeduction.h"
5353
#include "llvm/ADT/APFloat.h"
5454
#include "llvm/ADT/APInt.h"
55+
#include "llvm/ADT/STLExtras.h"
5556
#include "llvm/ADT/SmallString.h"
5657
#include "llvm/Support/Compiler.h"
5758
#include "llvm/Support/SaveAndRestore.h"
@@ -7499,8 +7500,18 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
74997500
fromEI.intoBuilder()
75007501
.withDifferentiabilityKind(toEI.getDifferentiabilityKind())
75017502
.build();
7502-
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult(),
7503-
newEI);
7503+
SmallVector<AnyFunctionType::Param, 4> params(fromFunc->getParams());
7504+
assert(params.size() == toFunc->getParams().size() && "unexpected @differentiable conversion");
7505+
// Propagate @noDerivate from target function type
7506+
for (auto paramAndIndex : llvm::enumerate(toFunc->getParams())) {
7507+
if (!paramAndIndex.value().isNoDerivative())
7508+
continue;
7509+
7510+
auto &param = params[paramAndIndex.index()];
7511+
param = param.withFlags(param.getParameterFlags().withNoDerivative(true));
7512+
}
7513+
7514+
fromFunc = FunctionType::get(params, fromFunc->getResult(), newEI);
75047515
switch (toEI.getDifferentiabilityKind()) {
75057516
// TODO: Ban `Normal` and `Forward` cases.
75067517
case DifferentiabilityKind::Normal:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/swiftlang/swift/issues/77871
4+
// Ensure we are correctl generating reabstraction thunks for Double <-> Optional<Double>
5+
// conversion for derivatives: for differential and pullback we need
6+
// to emit thunks to convert T.TangentVector <-> Optional<T>.TangentVector.
7+
8+
import _Differentiation
9+
10+
@differentiable(reverse)
11+
func testFunc(_ x: Double?) -> Double? {
12+
x! * x! * x!
13+
}
14+
print(pullback(at: 1.0, of: testFunc)(.init(1.0)) == 3.0)
15+

0 commit comments

Comments
 (0)