@@ -295,6 +295,90 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
295
295
});
296
296
}
297
297
298
+ ManagedValue
299
+ SILGenFunction::emitTangentVectorToOptionalTangentVector (SILLocation loc,
300
+ ManagedValue input,
301
+ CanType inputType,
302
+ CanType outputType,
303
+ SGFContext ctxt) {
304
+ auto *optionalTanDecl = outputType.getNominalOrBoundGenericNominal ();
305
+ // Look up the `Optional<T>.TangentVector.init` declaration.
306
+ auto initLookup =
307
+ optionalTanDecl->lookupDirect (DeclBaseName::createConstructor ());
308
+ ConstructorDecl *constructorDecl = nullptr ;
309
+ for (auto *candidate : initLookup) {
310
+ auto candidateModule = candidate->getModuleContext ();
311
+ if (candidateModule->getName () ==
312
+ getASTContext ().Id_Differentiation ||
313
+ candidateModule->isStdlibModule ()) {
314
+ assert (!constructorDecl && " Multiple `Optional.TangentVector.init`s" );
315
+ constructorDecl = cast<ConstructorDecl>(candidate);
316
+ #ifdef NDEBUG
317
+ break ;
318
+ #endif
319
+ }
320
+ }
321
+ assert (constructorDecl && " No `Optional.TangentVector.init`" );
322
+
323
+ // `Optional<T.TangentVector>`
324
+ CanType optionalOfWrappedTanType = inputType.wrapInOptionalType ();
325
+
326
+ const TypeLowering &optTL = getTypeLowering (optionalOfWrappedTanType);
327
+ auto optVal = emitInjectOptional (loc, optTL, ctxt,
328
+ [&](SGFContext objectCtxt) {
329
+ return input;
330
+ });
331
+ auto *diffProto = getASTContext ().getProtocol (KnownProtocolKind::Differentiable);
332
+ auto diffConf = lookupConformance (inputType, diffProto);
333
+ assert (!diffConf.isInvalid () && " Missing conformance to `Differentiable`" );
334
+ ConcreteDeclRef initDecl (constructorDecl,
335
+ SubstitutionMap::get (constructorDecl->getGenericSignature (),
336
+ {inputType}, {diffConf}));
337
+ PreparedArguments args ({AnyFunctionType::Param (optionalOfWrappedTanType)});
338
+ args.add (loc, RValue (*this , {optVal}, optionalOfWrappedTanType));
339
+
340
+ auto result = emitApplyAllocatingInitializer (loc, initDecl,
341
+ std::move (args), outputType, ctxt);
342
+ if (result.isInContext ())
343
+ return ManagedValue::forInContext ();
344
+ return std::move (result).getAsSingleValue (*this , loc);
345
+ }
346
+
347
+ ManagedValue
348
+ SILGenFunction::emitOptionalTangentVectorToTangentVector (SILLocation loc,
349
+ ManagedValue input,
350
+ CanType inputType,
351
+ CanType outputType,
352
+ SGFContext ctxt) {
353
+ // Optional<T>.TangentVector should be a struct with a single
354
+ // Optional<T.TangentVector> property. This is an implementation detail of
355
+ // OptionalDifferentiation.swift
356
+ // TODO: Maybe it would be better to have getters / setters here that we
357
+ // can call and hide this implementation detail?
358
+ StructDecl *optStructDecl = inputType.getStructOrBoundGenericStruct ();
359
+ VarDecl *wrappedValueVar = nullptr ;
360
+ if (optStructDecl) {
361
+ ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties ();
362
+ wrappedValueVar = properties.size () == 1 ? properties[0 ] : nullptr ;
363
+ }
364
+
365
+ EnumDecl *optDecl = wrappedValueVar ?
366
+ wrappedValueVar->getTypeInContext ()->getEnumOrBoundGenericEnum () :
367
+ nullptr ;
368
+
369
+ if (!optStructDecl || optDecl != getASTContext ().getOptionalDecl ())
370
+ llvm_unreachable (" Unexpected type of Optional.TangentVector" );
371
+
372
+ FormalEvaluationScope scope (*this );
373
+ auto wrappedVal = B.createStructExtract (loc, input, wrappedValueVar);
374
+ return emitCheckedGetOptionalValueFrom (loc, wrappedVal,
375
+ /* isImplicitUnwrap*/ true ,
376
+ getTypeLowering (wrappedVal.getType ()),
377
+ ctxt);
378
+ }
379
+
380
+
381
+
298
382
// / Apply this transformation to an arbitrary value.
299
383
RValue Transform::transform (RValue &&input,
300
384
AbstractionPattern inputOrigType,
@@ -689,51 +773,11 @@ ManagedValue Transform::transform(ManagedValue v,
689
773
// case when T == T.TangentVector)
690
774
auto inputTanSpace =
691
775
wrappedType->getAutoDiffTangentSpace (LookUpConformanceInModule ());
692
- if (inputTanSpace && inputTanSpace->getCanonicalType () == inputSubstType) {
693
- auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal ();
694
- // Look up the `Optional<T>.TangentVector.init` declaration.
695
- auto initLookup =
696
- optionalTanDecl->lookupDirect (DeclBaseName::createConstructor ());
697
- ConstructorDecl *constructorDecl = nullptr ;
698
- for (auto *candidate : initLookup) {
699
- auto candidateModule = candidate->getModuleContext ();
700
- if (candidateModule->getName () ==
701
- SGF. getASTContext ().Id_Differentiation ||
702
- candidateModule->isStdlibModule ()) {
703
- assert (!constructorDecl && " Multiple `Optional.TangentVector.init`s" );
704
- constructorDecl = cast<ConstructorDecl>(candidate);
705
- #ifdef NDEBUG
706
- break ;
707
- #endif
708
- }
709
- }
710
- assert (constructorDecl && " No `Optional.TangentVector.init`" );
711
-
712
- // `T.TangentVector`
713
- CanType wrappedTanType = inputTanSpace->getCanonicalType ();
714
- // `Optional<T.TangentVector>`
715
- CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType ();
716
-
717
- const TypeLowering &optTL = SGF.getTypeLowering (optionalOfWrappedTanType);
718
- auto optVal = SGF.emitInjectOptional (Loc, optTL, ctxt,
719
- [&](SGFContext objectCtxt) {
720
- return v;
721
- });
722
- auto *diffProto = SGF.getASTContext ().getProtocol (KnownProtocolKind::Differentiable);
723
- auto diffConf = lookupConformance (wrappedType, diffProto);
724
- assert (!diffConf.isInvalid () && " Missing conformance to `Differentiable`" );
725
- ConcreteDeclRef initDecl (constructorDecl,
726
- SubstitutionMap::get (constructorDecl->getGenericSignature (),
727
- {wrappedType}, {diffConf}));
728
- PreparedArguments args ({AnyFunctionType::Param (optionalOfWrappedTanType)});
729
- args.add (Loc, RValue (SGF, {optVal}, optionalOfWrappedTanType));
730
-
731
- auto result = SGF.emitApplyAllocatingInitializer (Loc, initDecl,
732
- std::move (args), outputSubstType, ctxt);
733
- if (result.isInContext ())
734
- return ManagedValue::forInContext ();
735
- return std::move (result).getAsSingleValue (SGF, Loc);
736
- }
776
+ if (inputTanSpace &&
777
+ inputTanSpace->getCanonicalType () == inputSubstType)
778
+ return SGF.emitTangentVectorToOptionalTangentVector (Loc, v,
779
+ inputSubstType, outputSubstType,
780
+ ctxt);
737
781
}
738
782
739
783
// - Optional<T>.TangentVector to T.TangentVector.
@@ -744,33 +788,11 @@ ManagedValue Transform::transform(ManagedValue v,
744
788
// case when T == T.TangentVector)
745
789
auto outputTanSpace =
746
790
wrappedType->getAutoDiffTangentSpace (LookUpConformanceInModule ());
747
- if (outputTanSpace && outputTanSpace->getCanonicalType () == outputSubstType) {
748
- // Optional<T>.TangentVector should be a struct with a single
749
- // Optional<T.TangentVector> property. This is an implementation detail of
750
- // OptionalDifferentiation.swift
751
- // TODO: Maybe it would be better to have getters / setters here that we
752
- // can call and hide this implementation detail?
753
- StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct ();
754
- VarDecl *wrappedValueVar = nullptr ;
755
- if (optStructDecl) {
756
- ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties ();
757
- wrappedValueVar = properties.size () == 1 ? properties[0 ] : nullptr ;
758
- }
759
-
760
- EnumDecl *optDecl = wrappedValueVar ?
761
- wrappedValueVar->getTypeInContext ()->getEnumOrBoundGenericEnum () :
762
- nullptr ;
763
-
764
- if (!optStructDecl || optDecl != SGF.getASTContext ().getOptionalDecl ())
765
- llvm_unreachable (" Unexpected type of Optional.TangentVector" );
766
-
767
- FormalEvaluationScope scope (SGF);
768
- auto wrappedVal = SGF.B .createStructExtract (Loc, v, wrappedValueVar);
769
- return SGF.emitCheckedGetOptionalValueFrom (Loc, wrappedVal,
770
- /* isImplicitUnwrap*/ true ,
771
- SGF.getTypeLowering (wrappedVal.getType ()),
772
- ctxt);
773
- }
791
+ if (outputTanSpace &&
792
+ outputTanSpace->getCanonicalType () == outputSubstType)
793
+ return SGF.emitOptionalTangentVectorToTangentVector (Loc, v,
794
+ inputSubstType, outputSubstType,
795
+ ctxt);
774
796
}
775
797
776
798
// Should have handled the conversion in one of the cases above.
0 commit comments