Skip to content

Commit 03334a8

Browse files
authored
[AutoDiff] Generalize handling of semantic result parameters (#67230)
Introduce the notion of "semantic result parameter". Handle differentiation of inouts via semantic result parameter abstraction. Do not consider non-wrt semantic result parameters as semantic results Fixes #67174
1 parent bb6df83 commit 03334a8

24 files changed

+365
-281
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
246246
return s;
247247
}
248248

249-
/// A semantic function result type: either a formal function result type or
250-
/// an `inout` parameter type. Used in derivative function type calculation.
249+
/// A semantic function result type: either a formal function result type or a
250+
/// semantic result (an `inout`) parameter type. Used in derivative function type
251+
/// calculation.
251252
struct AutoDiffSemanticFunctionResultType {
252253
Type type;
253254
unsigned index : 30;
254-
bool isInout : 1;
255-
bool isWrtParam : 1;
255+
bool isSemanticResultParameter : 1;
256256

257-
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
258-
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
257+
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param)
258+
: type(t), index(idx), isSemanticResultParameter(param) { }
259259
};
260260

261261
/// Key for caching SIL derivative function types.

include/swift/AST/Types.h

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase {
31923192
/// Whether the parameter is marked '@noDerivative'.
31933193
bool isNoDerivative() const { return Flags.isNoDerivative(); }
31943194

3195+
/// Whether the parameter might be a semantic result for autodiff purposes.
3196+
/// This includes inout parameters.
3197+
bool isAutoDiffSemanticResult() const {
3198+
return isInOut();
3199+
}
3200+
31953201
ValueOwnership getValueOwnership() const {
31963202
return Flags.getValueOwnership();
31973203
}
@@ -3509,19 +3515,16 @@ class AnyFunctionType : public TypeBase {
35093515
/// Preconditions:
35103516
/// - Parameters corresponding to parameter indices must conform to
35113517
/// `Differentiable`.
3512-
/// - There is one semantic function result type: either the formal original
3513-
/// result or an `inout` parameter. It must conform to `Differentiable`.
3518+
/// - There are semantic function result type: either the formal original
3519+
/// result or a "wrt" semantic result parameter.
35143520
///
35153521
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
35163522
/// "wrt" result derivative.
35173523
///
35183524
/// - Case 1: original function has no `inout` parameters.
35193525
/// - Original: `(T0, T1, ...) -> R`
35203526
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3521-
/// - Case 2: original function has a non-wrt `inout` parameter.
3522-
/// - Original: `(T0, inout T1, ...) -> Void`
3523-
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
3524-
/// - Case 3: original function has a wrt `inout` parameter.
3527+
/// - Case 2: original function has a wrt `inout` parameter.
35253528
/// - Original: `(T0, inout T1, ...) -> Void`
35263529
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
35273530
///
@@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase {
35313534
/// - Case 1: original function has no `inout` parameters.
35323535
/// - Original: `(T0, T1, ...) -> R`
35333536
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3534-
/// - Case 2: original function has a non-wrt `inout` parameter.
3535-
/// - Original: `(T0, inout T1, ...) -> Void`
3536-
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3537-
/// - Case 3: original function has a wrt `inout` parameter.
3537+
/// - Case 2: original function has a wrt `inout` parameter.
35383538
/// - Original: `(T0, inout T1, ...) -> Void`
35393539
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
35403540
///
@@ -4101,6 +4101,9 @@ class SILParameterInfo {
41014101
return getConvention() == ParameterConvention::Indirect_Inout
41024102
|| getConvention() == ParameterConvention::Indirect_InoutAliasable;
41034103
}
4104+
bool isAutoDiffSemanticResult() const {
4105+
return isIndirectMutating();
4106+
}
41044107

41054108
bool isPack() const {
41064109
return isPackParameter(getConvention());
@@ -4836,6 +4839,37 @@ class SILFunctionType final
48364839
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
48374840
}
48384841

4842+
struct AutoDiffSemanticResultsParameterFilter {
4843+
bool operator()(SILParameterInfo param) const {
4844+
return param.isAutoDiffSemanticResult();
4845+
}
4846+
};
4847+
4848+
using AutoDiffSemanticResultsParameterIter =
4849+
llvm::filter_iterator<const SILParameterInfo *,
4850+
AutoDiffSemanticResultsParameterFilter>;
4851+
using AutoDiffSemanticResultsParameterRange =
4852+
iterator_range<AutoDiffSemanticResultsParameterIter>;
4853+
4854+
/// A range of SILParameterInfo for all semantic results parameters.
4855+
AutoDiffSemanticResultsParameterRange
4856+
getAutoDiffSemanticResultsParameters() const {
4857+
return llvm::make_filter_range(getParameters(),
4858+
AutoDiffSemanticResultsParameterFilter());
4859+
}
4860+
4861+
/// Returns the number of semantic results parameters.
4862+
unsigned getNumAutoDiffSemanticResultsParameters() const {
4863+
return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter());
4864+
}
4865+
4866+
/// Returns the number of function potential semantic results:
4867+
/// * Usual results
4868+
/// * Inout parameters
4869+
unsigned getNumAutoDiffSemanticResults() const {
4870+
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
4871+
}
4872+
48394873
/// Get the generic signature that the component types are specified
48404874
/// in terms of, if any.
48414875
CanGenericSignature getSubstGenericSignature() const {

include/swift/SIL/ApplySite.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,18 @@ class FullApplySite : public ApplySite {
681681
llvm_unreachable("invalid apply kind");
682682
}
683683

684+
AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
685+
switch (getKind()) {
686+
case FullApplySiteKind::ApplyInst:
687+
return cast<ApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
688+
case FullApplySiteKind::TryApplyInst:
689+
return cast<TryApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
690+
case FullApplySiteKind::BeginApplyInst:
691+
return cast<BeginApplyInst>(getInstruction())->getAutoDiffSemanticResultArguments();
692+
}
693+
llvm_unreachable("invalid apply kind");
694+
}
695+
684696
/// Returns true if \p op is the callee operand of this apply site
685697
/// and not an argument operand.
686698
bool isCalleeOperand(const Operand &op) const {

include/swift/SIL/SILInstruction.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2785,6 +2785,25 @@ struct OperandToInoutArgument {
27852785
using InoutArgumentRange =
27862786
OptionalTransformRange<IntRange<size_t>, OperandToInoutArgument>;
27872787

2788+
/// Predicate used to filter AutoDiffSemanticResultArgumentRange.
2789+
struct OperandToAutoDiffSemanticResultArgument {
2790+
ArrayRef<SILParameterInfo> paramInfos;
2791+
OperandValueArrayRef arguments;
2792+
OperandToAutoDiffSemanticResultArgument(ArrayRef<SILParameterInfo> paramInfos,
2793+
OperandValueArrayRef arguments)
2794+
: paramInfos(paramInfos), arguments(arguments) {
2795+
assert(paramInfos.size() == arguments.size());
2796+
}
2797+
llvm::Optional<SILValue> operator()(size_t i) const {
2798+
if (paramInfos[i].isAutoDiffSemanticResult())
2799+
return arguments[i];
2800+
return llvm::None;
2801+
}
2802+
};
2803+
2804+
using AutoDiffSemanticResultArgumentRange =
2805+
OptionalTransformRange<IntRange<size_t>, OperandToAutoDiffSemanticResultArgument>;
2806+
27882807
/// The partial specialization of ApplyInstBase for full applications.
27892808
/// Adds some methods relating to 'self' and to result types that don't
27902809
/// make sense for partial applications.
@@ -2894,6 +2913,16 @@ class ApplyInstBase<Impl, Base, true>
28942913
impl.getArgumentsWithoutIndirectResults()));
28952914
}
28962915

2916+
/// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`)
2917+
/// arguments passed to the instruction.
2918+
AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const {
2919+
auto &impl = asImpl();
2920+
return AutoDiffSemanticResultArgumentRange(
2921+
indices(getArgumentsWithoutIndirectResults()),
2922+
OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(),
2923+
impl.getArgumentsWithoutIndirectResults()));
2924+
}
2925+
28972926
bool hasSemantics(StringRef semanticsString) const {
28982927
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
28992928
}

lib/AST/AutoDiff.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,32 +199,28 @@ void autodiff::getFunctionSemanticResults(
199199
if (formalResultType->is<TupleType>()) {
200200
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
201201
resultTypes.emplace_back(elt.getType(), resultIdx++,
202-
/*isInout*/ false, /*isWrt*/ false);
202+
/*isParameter*/ false);
203203
}
204204
} else {
205205
resultTypes.emplace_back(formalResultType, resultIdx++,
206-
/*isInout*/ false, /*isWrt*/ false);
206+
/*isParameter*/ false);
207207
}
208208
}
209209

210-
bool addNonWrts = resultTypes.empty();
211-
212-
// Collect wrt `inout` parameters as semantic results
213-
// As an extention, collect all (including non-wrt) inouts as results for
214-
// functions returning void.
210+
// Collect wrt semantic result (`inout`) parameters as
211+
// semantic results
215212
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
216213
unsigned curryOffset = 0) {
217214
for (auto paramAndIndex : enumerate(functionType->getParams())) {
218-
if (!paramAndIndex.value().isInOut())
215+
if (!paramAndIndex.value().isAutoDiffSemanticResult())
219216
continue;
220217

221218
unsigned idx = paramAndIndex.index() + curryOffset;
222219
assert(idx < parameterIndices->getCapacity() &&
223220
"invalid parameter index");
224-
bool isWrt = parameterIndices->contains(idx);
225-
if (addNonWrts || isWrt)
221+
if (parameterIndices->contains(idx))
226222
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
227-
resultIdx, /*isInout*/ true, isWrt);
223+
resultIdx, /*isParameter*/ true);
228224
resultIdx += 1;
229225
}
230226
};

lib/AST/Type.cpp

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55585558
return llvm::make_error<DerivativeFunctionTypeError>(
55595559
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
55605560

5561-
// Accumulate non-inout result tangent spaces.
5561+
// Accumulate non-semantic result tangent spaces.
55625562
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
55635563
for (auto i : range(originalResults.size())) {
55645564
auto originalResult = originalResults[i];
@@ -5577,16 +5577,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55775577
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
55785578
std::make_pair(originalResultType, unsigned(originalResult.index)));
55795579

5580-
if (!originalResult.isInout)
5580+
if (!originalResult.isSemanticResultParameter)
55815581
resultTanTypes.push_back(resultTan->getType());
5582-
else if (originalResult.isInout && !originalResult.isWrtParam)
5583-
inoutTanTypes.push_back(resultTan->getType());
55845582
}
55855583

5586-
// Treat non-wrt inouts as semantic results for functions returning Void
5587-
if (resultTanTypes.empty())
5588-
resultTanTypes = inoutTanTypes;
5589-
55905584
// Compute the result linear map function type.
55915585
FunctionType *linearMapType;
55925586
switch (kind) {
@@ -5597,11 +5591,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55975591
// - Original: `(T0, T1, ...) -> R`
55985592
// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
55995593
//
5600-
// Case 2: original function has a non-wrt `inout` parameter.
5601-
// - Original: `(T0, inout T1, ...) -> Void`
5602-
// - Differential: `(T0.Tan, ...) -> T1.Tan`
5603-
//
5604-
// Case 3: original function has a wrt `inout` parameter.
5594+
// Case 2: original function has a wrt `inout` parameter.
56055595
// - Original: `(T0, inout T1, ...) -> Void`
56065596
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
56075597
SmallVector<AnyFunctionType::Param, 4> differentialParams;
@@ -5648,15 +5638,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56485638
// - Original: `(T0, T1, ...) -> R`
56495639
// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
56505640
//
5651-
// Case 2: original function has a non-wrt `inout` parameter.
5652-
// - Original: `(T0, inout T1, ...) -> Void`
5653-
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
5654-
//
5655-
// Case 3: original function has wrt `inout` parameters.
5641+
// Case 2: original function has wrt `inout` parameters.
56565642
// - Original: `(T0, inout T1, ...) -> R`
56575643
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
56585644
SmallVector<TupleTypeElt, 4> pullbackResults;
5659-
SmallVector<AnyFunctionType::Param, 2> inoutParams;
5645+
SmallVector<AnyFunctionType::Param, 2> semanticResultParams;
56605646
for (auto i : range(diffParams.size())) {
56615647
auto diffParam = diffParams[i];
56625648
auto paramType = diffParam.getPlainType();
@@ -5669,10 +5655,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56695655
NonDifferentiableDifferentiabilityParameter,
56705656
std::make_pair(paramType, i));
56715657

5672-
if (diffParam.isInOut()) {
5658+
if (diffParam.isAutoDiffSemanticResult()) {
56735659
if (paramType->isVoid())
56745660
continue;
5675-
inoutParams.push_back(diffParam);
5661+
semanticResultParams.push_back(diffParam);
56765662
continue;
56775663
}
56785664
pullbackResults.emplace_back(paramTan->getType());
@@ -5693,22 +5679,23 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56935679
pullbackParams.push_back(AnyFunctionType::Param(
56945680
resultTanType, Identifier(), flags));
56955681
}
5696-
// Then append inout parameters.
5697-
for (auto i : range(inoutParams.size())) {
5698-
auto inoutParam = inoutParams[i];
5699-
auto inoutParamType = inoutParam.getPlainType();
5700-
auto inoutParamTan =
5701-
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
5682+
// Then append semantic result parameters.
5683+
for (auto i : range(semanticResultParams.size())) {
5684+
auto semanticResultParam = semanticResultParams[i];
5685+
auto semanticResultParamType = semanticResultParam.getPlainType();
5686+
auto semanticResultParamTan =
5687+
semanticResultParamType->getAutoDiffTangentSpace(lookupConformance);
57025688
auto flags = ParameterTypeFlags().withInOut(true);
57035689
pullbackParams.push_back(AnyFunctionType::Param(
5704-
inoutParamTan->getType(), Identifier(), flags));
5690+
semanticResultParamTan->getType(), Identifier(), flags));
57055691
}
57065692
// FIXME: Verify ExtInfo state is correct, not working by accident.
57075693
FunctionType::ExtInfo info;
57085694
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
57095695
break;
57105696
}
57115697
}
5698+
57125699
assert(linearMapType && "Expected linear map type");
57135700
return linearMapType;
57145701
}

0 commit comments

Comments
 (0)