Skip to content

Commit b47b157

Browse files
committed
Factor out emission into separate helpers
1 parent 5a68861 commit b47b157

File tree

2 files changed

+111
-72
lines changed

2 files changed

+111
-72
lines changed

lib/SILGen/SILGenFunction.h

+17
Original file line numberDiff line numberDiff line change
@@ -2505,6 +2505,23 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
25052505
CanSILFunctionType toType,
25062506
bool reorderSelf);
25072507

2508+
/// Emit conversion from T.TangentVector to Optional<T>.TangentVector.
2509+
ManagedValue
2510+
emitTangentVectorToOptionalTangentVector(SILLocation loc,
2511+
ManagedValue input,
2512+
CanType inputType,
2513+
CanType outputType,
2514+
SGFContext ctxt);
2515+
2516+
/// Emit conversion from Optional<T>.TangentVector to T.TangentVector.
2517+
ManagedValue
2518+
emitOptionalTangentVectorToTangentVector(SILLocation loc,
2519+
ManagedValue input,
2520+
CanType inputType,
2521+
CanType outputType,
2522+
SGFContext ctxt);
2523+
2524+
25082525
//===--------------------------------------------------------------------===//
25092526
// Back Deployment thunks
25102527
//===--------------------------------------------------------------------===//

lib/SILGen/SILGenPoly.cpp

+94-72
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,90 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
295295
});
296296
}
297297

298+
ManagedValue
299+
SILGenFunction::emitTangentVectorToOptionalTangentVector(SILLocation loc,
300+
ManagedValue input,
301+
CanType inputType,
302+
CanType outputType,
303+
SGFContext ctxt) {
304+
auto *optionalTanDecl = outputType.getNominalOrBoundGenericNominal();
305+
// Look up the `Optional<T>.TangentVector.init` declaration.
306+
auto initLookup =
307+
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
308+
ConstructorDecl *constructorDecl = nullptr;
309+
for (auto *candidate : initLookup) {
310+
auto candidateModule = candidate->getModuleContext();
311+
if (candidateModule->getName() ==
312+
getASTContext().Id_Differentiation ||
313+
candidateModule->isStdlibModule()) {
314+
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
315+
constructorDecl = cast<ConstructorDecl>(candidate);
316+
#ifdef NDEBUG
317+
break;
318+
#endif
319+
}
320+
}
321+
assert(constructorDecl && "No `Optional.TangentVector.init`");
322+
323+
// `Optional<T.TangentVector>`
324+
CanType optionalOfWrappedTanType = inputType.wrapInOptionalType();
325+
326+
const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType);
327+
auto optVal = emitInjectOptional(loc, optTL, ctxt,
328+
[&](SGFContext objectCtxt) {
329+
return input;
330+
});
331+
auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable);
332+
auto diffConf = lookupConformance(inputType, diffProto);
333+
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
334+
ConcreteDeclRef initDecl(constructorDecl,
335+
SubstitutionMap::get(constructorDecl->getGenericSignature(),
336+
{inputType}, {diffConf}));
337+
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
338+
args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType));
339+
340+
auto result = emitApplyAllocatingInitializer(loc, initDecl,
341+
std::move(args), outputType, ctxt);
342+
if (result.isInContext())
343+
return ManagedValue::forInContext();
344+
return std::move(result).getAsSingleValue(*this, loc);
345+
}
346+
347+
ManagedValue
348+
SILGenFunction::emitOptionalTangentVectorToTangentVector(SILLocation loc,
349+
ManagedValue input,
350+
CanType inputType,
351+
CanType outputType,
352+
SGFContext ctxt) {
353+
// Optional<T>.TangentVector should be a struct with a single
354+
// Optional<T.TangentVector> property. This is an implementation detail of
355+
// OptionalDifferentiation.swift
356+
// TODO: Maybe it would be better to have getters / setters here that we
357+
// can call and hide this implementation detail?
358+
StructDecl *optStructDecl = inputType.getStructOrBoundGenericStruct();
359+
VarDecl *wrappedValueVar = nullptr;
360+
if (optStructDecl) {
361+
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
362+
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
363+
}
364+
365+
EnumDecl *optDecl = wrappedValueVar ?
366+
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
367+
nullptr;
368+
369+
if (!optStructDecl || optDecl != getASTContext().getOptionalDecl())
370+
llvm_unreachable("Unexpected type of Optional.TangentVector");
371+
372+
FormalEvaluationScope scope(*this);
373+
auto wrappedVal = B.createStructExtract(loc, input, wrappedValueVar);
374+
return emitCheckedGetOptionalValueFrom(loc, wrappedVal,
375+
/*isImplicitUnwrap*/ true,
376+
getTypeLowering(wrappedVal.getType()),
377+
ctxt);
378+
}
379+
380+
381+
298382
/// Apply this transformation to an arbitrary value.
299383
RValue Transform::transform(RValue &&input,
300384
AbstractionPattern inputOrigType,
@@ -689,51 +773,11 @@ ManagedValue Transform::transform(ManagedValue v,
689773
// case when T == T.TangentVector)
690774
auto inputTanSpace =
691775
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
692-
if (inputTanSpace && inputTanSpace->getCanonicalType() == inputSubstType) {
693-
auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal();
694-
// Look up the `Optional<T>.TangentVector.init` declaration.
695-
auto initLookup =
696-
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
697-
ConstructorDecl *constructorDecl = nullptr;
698-
for (auto *candidate : initLookup) {
699-
auto candidateModule = candidate->getModuleContext();
700-
if (candidateModule->getName() ==
701-
SGF. getASTContext().Id_Differentiation ||
702-
candidateModule->isStdlibModule()) {
703-
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
704-
constructorDecl = cast<ConstructorDecl>(candidate);
705-
#ifdef NDEBUG
706-
break;
707-
#endif
708-
}
709-
}
710-
assert(constructorDecl && "No `Optional.TangentVector.init`");
711-
712-
// `T.TangentVector`
713-
CanType wrappedTanType = inputTanSpace->getCanonicalType();
714-
// `Optional<T.TangentVector>`
715-
CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType();
716-
717-
const TypeLowering &optTL = SGF.getTypeLowering(optionalOfWrappedTanType);
718-
auto optVal = SGF.emitInjectOptional(Loc, optTL, ctxt,
719-
[&](SGFContext objectCtxt) {
720-
return v;
721-
});
722-
auto *diffProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
723-
auto diffConf = lookupConformance(wrappedType, diffProto);
724-
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
725-
ConcreteDeclRef initDecl(constructorDecl,
726-
SubstitutionMap::get(constructorDecl->getGenericSignature(),
727-
{wrappedType}, {diffConf}));
728-
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
729-
args.add(Loc, RValue(SGF, {optVal}, optionalOfWrappedTanType));
730-
731-
auto result = SGF.emitApplyAllocatingInitializer(Loc, initDecl,
732-
std::move(args), outputSubstType, ctxt);
733-
if (result.isInContext())
734-
return ManagedValue::forInContext();
735-
return std::move(result).getAsSingleValue(SGF, Loc);
736-
}
776+
if (inputTanSpace &&
777+
inputTanSpace->getCanonicalType() == inputSubstType)
778+
return SGF.emitTangentVectorToOptionalTangentVector(Loc, v,
779+
inputSubstType, outputSubstType,
780+
ctxt);
737781
}
738782

739783
// - Optional<T>.TangentVector to T.TangentVector.
@@ -744,33 +788,11 @@ ManagedValue Transform::transform(ManagedValue v,
744788
// case when T == T.TangentVector)
745789
auto outputTanSpace =
746790
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
747-
if (outputTanSpace && outputTanSpace->getCanonicalType() == outputSubstType) {
748-
// Optional<T>.TangentVector should be a struct with a single
749-
// Optional<T.TangentVector> property. This is an implementation detail of
750-
// OptionalDifferentiation.swift
751-
// TODO: Maybe it would be better to have getters / setters here that we
752-
// can call and hide this implementation detail?
753-
StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct();
754-
VarDecl *wrappedValueVar = nullptr;
755-
if (optStructDecl) {
756-
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
757-
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
758-
}
759-
760-
EnumDecl *optDecl = wrappedValueVar ?
761-
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
762-
nullptr;
763-
764-
if (!optStructDecl || optDecl != SGF.getASTContext().getOptionalDecl())
765-
llvm_unreachable("Unexpected type of Optional.TangentVector");
766-
767-
FormalEvaluationScope scope(SGF);
768-
auto wrappedVal = SGF.B.createStructExtract(Loc, v, wrappedValueVar);
769-
return SGF.emitCheckedGetOptionalValueFrom(Loc, wrappedVal,
770-
/*isImplicitUnwrap*/ true,
771-
SGF.getTypeLowering(wrappedVal.getType()),
772-
ctxt);
773-
}
791+
if (outputTanSpace &&
792+
outputTanSpace->getCanonicalType() == outputSubstType)
793+
return SGF.emitOptionalTangentVectorToTangentVector(Loc, v,
794+
inputSubstType, outputSubstType,
795+
ctxt);
774796
}
775797

776798
// Should have handled the conversion in one of the cases above.

0 commit comments

Comments
 (0)