|
83 | 83 | //===----------------------------------------------------------------------===//
|
84 | 84 |
|
85 | 85 | #define DEBUG_TYPE "silgen-poly"
|
| 86 | +#include "ArgumentSource.h" |
86 | 87 | #include "ExecutorBreadcrumb.h"
|
87 | 88 | #include "FunctionInputGenerator.h"
|
88 | 89 | #include "Initialization.h"
|
@@ -675,6 +676,103 @@ ManagedValue Transform::transform(ManagedValue v,
|
675 | 676 | return std::move(result).getAsSingleValue(SGF, Loc);
|
676 | 677 | }
|
677 | 678 |
|
| 679 | + // - T.TangentVector to Optional<T>.TangentVector |
| 680 | + // Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector> |
| 681 | + // So we just need to call appropriate .init on it. |
| 682 | + // However, we might have T.TangentVector == T, so we need to calculate all |
| 683 | + // required types first. |
| 684 | + if (CanType optionalTy = outputSubstType.getNominalParent(); // `Optional<T>` |
| 685 | + optionalTy && (bool)optionalTy.getOptionalObjectType()) { |
| 686 | + // `T` |
| 687 | + CanType wrappedType = optionalTy.getOptionalObjectType(); |
| 688 | + // Check that T.TangentVector is indeed inputSubstType (this also handles |
| 689 | + // case when T == T.TangentVector) |
| 690 | + auto inputTanSpace = |
| 691 | + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 692 | + if (inputTanSpace && inputTanSpace->getCanonicalType() == inputSubstType) { |
| 693 | + auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal(); |
| 694 | + // Look up the `Optional<T>.TangentVector.init` declaration. |
| 695 | + auto initLookup = |
| 696 | + optionalTanDecl->lookupDirect(DeclBaseName::createConstructor()); |
| 697 | + ConstructorDecl *constructorDecl = nullptr; |
| 698 | + for (auto *candidate : initLookup) { |
| 699 | + auto candidateModule = candidate->getModuleContext(); |
| 700 | + if (candidateModule->getName() == |
| 701 | + SGF. getASTContext().Id_Differentiation || |
| 702 | + candidateModule->isStdlibModule()) { |
| 703 | + assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s"); |
| 704 | + constructorDecl = cast<ConstructorDecl>(candidate); |
| 705 | +#ifdef NDEBUG |
| 706 | + break; |
| 707 | +#endif |
| 708 | + } |
| 709 | + } |
| 710 | + assert(constructorDecl && "No `Optional.TangentVector.init`"); |
| 711 | + |
| 712 | + // `T.TangentVector` |
| 713 | + CanType wrappedTanType = inputTanSpace->getCanonicalType(); |
| 714 | + // `Optional<T.TangentVector>` |
| 715 | + CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType(); |
| 716 | + |
| 717 | + const TypeLowering &optTL = SGF.getTypeLowering(optionalOfWrappedTanType); |
| 718 | + auto optVal = SGF.emitInjectOptional(Loc, optTL, ctxt, |
| 719 | + [&](SGFContext objectCtxt) { |
| 720 | + return v; |
| 721 | + }); |
| 722 | + auto *diffProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Differentiable); |
| 723 | + auto diffConf = lookupConformance(wrappedType, diffProto); |
| 724 | + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
| 725 | + ConcreteDeclRef initDecl(constructorDecl, |
| 726 | + SubstitutionMap::get(constructorDecl->getGenericSignature(), |
| 727 | + {wrappedType}, {diffConf})); |
| 728 | + PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)}); |
| 729 | + args.add(Loc, RValue(SGF, {optVal}, optionalOfWrappedTanType)); |
| 730 | + |
| 731 | + auto result = SGF.emitApplyAllocatingInitializer(Loc, initDecl, |
| 732 | + std::move(args), outputSubstType, ctxt); |
| 733 | + if (result.isInContext()) |
| 734 | + return ManagedValue::forInContext(); |
| 735 | + return std::move(result).getAsSingleValue(SGF, Loc); |
| 736 | + } |
| 737 | + } |
| 738 | + |
| 739 | + // - Optional<T>.TangentVector to T.TangentVector. |
| 740 | + if (CanType optionalTy = inputSubstType.getNominalParent(); // `Optional<T>` |
| 741 | + optionalTy && (bool)optionalTy.getOptionalObjectType()) { |
| 742 | + CanType wrappedType = optionalTy.getOptionalObjectType(); // `T` |
| 743 | + // Check that T.TangentVector is indeed outputSubstType (this also handles |
| 744 | + // case when T == T.TangentVector) |
| 745 | + auto outputTanSpace = |
| 746 | + wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule()); |
| 747 | + if (outputTanSpace && outputTanSpace->getCanonicalType() == outputSubstType) { |
| 748 | + // Optional<T>.TangentVector should be a struct with a single |
| 749 | + // Optional<T.TangentVector> property. This is an implementation detail of |
| 750 | + // OptionalDifferentiation.swift |
| 751 | + // TODO: Maybe it would be better to have getters / setters here that we |
| 752 | + // can call and hide this implementation detail? |
| 753 | + StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct(); |
| 754 | + VarDecl *wrappedValueVar = nullptr; |
| 755 | + if (optStructDecl) { |
| 756 | + ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties(); |
| 757 | + wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr; |
| 758 | + } |
| 759 | + |
| 760 | + EnumDecl *optDecl = wrappedValueVar ? |
| 761 | + wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() : |
| 762 | + nullptr; |
| 763 | + |
| 764 | + if (!optStructDecl || optDecl != SGF.getASTContext().getOptionalDecl()) |
| 765 | + llvm_unreachable("Unexpected type of Optional.TangentVector"); |
| 766 | + |
| 767 | + FormalEvaluationScope scope(SGF); |
| 768 | + auto wrappedVal = SGF.B.createStructExtract(Loc, v, wrappedValueVar); |
| 769 | + return SGF.emitCheckedGetOptionalValueFrom(Loc, wrappedVal, |
| 770 | + /*isImplicitUnwrap*/ true, |
| 771 | + SGF.getTypeLowering(wrappedVal.getType()), |
| 772 | + ctxt); |
| 773 | + } |
| 774 | + } |
| 775 | + |
678 | 776 | // Should have handled the conversion in one of the cases above.
|
679 | 777 | v.dump();
|
680 | 778 | llvm_unreachable("Unhandled transform?");
|
|
0 commit comments