|
83 | 83 | //===----------------------------------------------------------------------===//
|
84 | 84 |
|
85 | 85 | #define DEBUG_TYPE "silgen-poly"
|
| 86 | +#include "ArgumentSource.h" |
86 | 87 | #include "ExecutorBreadcrumb.h"
|
87 | 88 | #include "FunctionInputGenerator.h"
|
88 | 89 | #include "Initialization.h"
|
@@ -294,6 +295,67 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
|
294 | 295 | });
|
295 | 296 | }
|
296 | 297 |
|
| 298 | +// Convert T.TangentVector to Optional<T>.TangentVector. |
| 299 | +// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector> |
| 300 | +// So we just need to call appropriate .init on it. |
| 301 | +ManagedValue SILGenFunction::emitTangentVectorToOptionalTangentVector( |
| 302 | + SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType, |
| 303 | + CanType outputType, SGFContext ctxt) { |
| 304 | + // Look up the `Optional<T>.TangentVector.init` declaration. |
| 305 | + auto *constructorDecl = getASTContext().getOptionalTanInitDecl(outputType); |
| 306 | + |
| 307 | + // `Optional<T.TangentVector>` |
| 308 | + CanType optionalOfWrappedTanType = inputType.wrapInOptionalType(); |
| 309 | + |
| 310 | + const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType); |
| 311 | + auto optVal = emitInjectOptional( |
| 312 | + loc, optTL, SGFContext(), [&](SGFContext objectCtxt) { return input; }); |
| 313 | + |
| 314 | + auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable); |
| 315 | + auto diffConf = lookupConformance(wrappedType, diffProto); |
| 316 | + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
| 317 | + ConcreteDeclRef initDecl( |
| 318 | + constructorDecl, |
| 319 | + SubstitutionMap::get(constructorDecl->getGenericSignature(), |
| 320 | + {wrappedType}, {diffConf})); |
| 321 | + PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)}); |
| 322 | + args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType)); |
| 323 | + |
| 324 | + auto result = emitApplyAllocatingInitializer(loc, initDecl, std::move(args), |
| 325 | + Type(), ctxt); |
| 326 | + return std::move(result).getScalarValue(); |
| 327 | +} |
| 328 | + |
| 329 | +ManagedValue SILGenFunction::emitOptionalTangentVectorToTangentVector( |
| 330 | + SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType, |
| 331 | + CanType outputType, SGFContext ctxt) { |
| 332 | + // Optional<T>.TangentVector should be a struct with a single |
| 333 | + // Optional<T.TangentVector> `value` property. This is an implementation |
| 334 | + // detail of OptionalDifferentiation.swift |
| 335 | + // TODO: Maybe it would be better to have explicit getters / setters here that we can |
| 336 | + // call and hide this implementation detail? |
| 337 | + VarDecl *wrappedValueVar = getASTContext().getOptionalTanValueDecl(inputType); |
| 338 | + // `Optional<T.TangentVector>` |
| 339 | + CanType optionalOfWrappedTanType = outputType.wrapInOptionalType(); |
| 340 | + |
| 341 | + FormalEvaluationScope scope(*this); |
| 342 | + |
| 343 | + auto sig = wrappedValueVar->getDeclContext()->getGenericSignatureOfContext(); |
| 344 | + auto *diffProto = |
| 345 | + getASTContext().getProtocol(KnownProtocolKind::Differentiable); |
| 346 | + auto diffConf = lookupConformance(wrappedType, diffProto); |
| 347 | + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
| 348 | + |
| 349 | + auto wrappedVal = emitRValueForStorageLoad( |
| 350 | + loc, input, inputType, /*super*/ false, wrappedValueVar, |
| 351 | + PreparedArguments(), SubstitutionMap::get(sig, {wrappedType}, {diffConf}), |
| 352 | + AccessSemantics::Ordinary, optionalOfWrappedTanType, SGFContext()); |
| 353 | + |
| 354 | + return emitCheckedGetOptionalValueFrom( |
| 355 | + loc, std::move(wrappedVal).getScalarValue(), |
| 356 | + /*isImplicitUnwrap*/ true, getTypeLowering(optionalOfWrappedTanType), ctxt); |
| 357 | +} |
| 358 | + |
297 | 359 | /// Apply this transformation to an arbitrary value.
|
298 | 360 | RValue Transform::transform(RValue &&input,
|
299 | 361 | AbstractionPattern inputOrigType,
|
@@ -675,6 +737,54 @@ ManagedValue Transform::transform(ManagedValue v,
|
675 | 737 | return std::move(result).getAsSingleValue(SGF, Loc);
|
676 | 738 | }
|
677 | 739 |
|
| 740 | + // - T.TangentVector to Optional<T>.TangentVector |
| 741 | + // Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector> |
| 742 | + // So we just need to call appropriate .init on it. |
| 743 | + // However, we might have T.TangentVector == T, so we need to calculate all |
| 744 | + // required types first. |
| 745 | + { |
| 746 | + CanType optionalTy = isa<NominalType>(outputSubstType) |
| 747 | + ? outputSubstType.getNominalParent() |
| 748 | + : CanType(); // `Optional<T>` |
| 749 | + if (optionalTy && (bool)optionalTy.getOptionalObjectType()) { |
| 750 | + CanType wrappedType = optionalTy.getOptionalObjectType(); // `T` |
| 751 | + // Check that T.TangentVector is indeed inputSubstType (this also handles |
| 752 | + // case when T == T.TangentVector). |
| 753 | + // Also check that outputSubstType is an Optional<T>.TangentVector. |
| 754 | + auto inputTanSpace = |
| 755 | + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 756 | + auto outputTanSpace = |
| 757 | + optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 758 | + if (inputTanSpace && outputTanSpace && |
| 759 | + inputTanSpace->getCanonicalType() == inputSubstType && |
| 760 | + outputTanSpace->getCanonicalType() == outputSubstType) |
| 761 | + return SGF.emitTangentVectorToOptionalTangentVector( |
| 762 | + Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt); |
| 763 | + } |
| 764 | + } |
| 765 | + |
| 766 | + // - Optional<T>.TangentVector to T.TangentVector. |
| 767 | + { |
| 768 | + CanType optionalTy = isa<NominalType>(inputSubstType) |
| 769 | + ? inputSubstType.getNominalParent() |
| 770 | + : CanType(); // `Optional<T>` |
| 771 | + if (optionalTy && (bool)optionalTy.getOptionalObjectType()) { |
| 772 | + CanType wrappedType = optionalTy.getOptionalObjectType(); // `T` |
| 773 | + // Check that T.TangentVector is indeed outputSubstType (this also handles |
| 774 | + // case when T == T.TangentVector) |
| 775 | + // Also check that inputSubstType is an Optional<T>.TangentVector |
| 776 | + auto inputTanSpace = |
| 777 | + optionalTy->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 778 | + auto outputTanSpace = |
| 779 | + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 780 | + if (inputTanSpace && outputTanSpace && |
| 781 | + inputTanSpace->getCanonicalType() == inputSubstType && |
| 782 | + outputTanSpace->getCanonicalType() == outputSubstType) |
| 783 | + return SGF.emitOptionalTangentVectorToTangentVector( |
| 784 | + Loc, v, wrappedType, inputSubstType, outputSubstType, ctxt); |
| 785 | + } |
| 786 | + } |
| 787 | + |
678 | 788 | // Should have handled the conversion in one of the cases above.
|
679 | 789 | v.dump();
|
680 | 790 | llvm_unreachable("Unhandled transform?");
|
|
0 commit comments