diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index e5f6d646eae94..4e6318bbe0250 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -32,6 +32,7 @@ namespace swift { +class AbstractFunctionDecl; class AnyFunctionType; class SourceFile; class SILFunctionType; @@ -247,7 +248,12 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, /// an `inout` parameter type. Used in derivative function type calculation. struct AutoDiffSemanticFunctionResultType { Type type; - bool isInout; + unsigned index : 30; + bool isInout : 1; + bool isWrtParam : 1; + + AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt) + : type(t), index(idx), isInout(inout), isWrtParam(wrt) { } }; /// Key for caching SIL derivative function types. @@ -398,9 +404,6 @@ class DerivativeFunctionTypeError enum class Kind { /// Original function type has no semantic results. NoSemanticResults, - /// Original function type has multiple semantic results. - // TODO(TF-1250): Support function types with multiple semantic results. - MultipleSemanticResults, /// Differentiability parmeter indices are empty. NoDifferentiabilityParameters, /// A differentiability parameter does not conform to `Differentiable`. @@ -429,7 +432,6 @@ class DerivativeFunctionTypeError explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind) : functionType(functionType), kind(kind), value(Value()) { assert(kind == Kind::NoSemanticResults || - kind == Kind::MultipleSemanticResults || kind == Kind::NoDifferentiabilityParameters); }; @@ -572,15 +574,22 @@ namespace autodiff { /// `inout` parameter types. /// /// The function type may have at most two parameter lists. -/// -/// Remaps the original semantic result using `genericEnv`, if specified. -void getFunctionSemanticResultTypes( - AnyFunctionType *functionType, - SmallVectorImpl &result, - GenericEnvironment *genericEnv = nullptr); +void getFunctionSemanticResults( + const AnyFunctionType *functionType, + const IndexSubset *parameterIndices, + SmallVectorImpl &resultTypes); + +/// Returns the indices of semantic results for a given function. +IndexSubset *getFunctionSemanticResultIndices( + const AnyFunctionType *functionType, + const IndexSubset *parameterIndices); + +IndexSubset *getFunctionSemanticResultIndices( + const AbstractFunctionDecl *AFD, + const IndexSubset *parameterIndices); /// Returns the lowered SIL parameter indices for the given AST parameter -/// indices and `AnyfunctionType`. +/// indices and `AnyFunctionType`. /// /// Notable lowering-related changes: /// - AST tuple parameter types are exploded when lowered to SIL. diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 628905882a691..0c4e6a761bc98 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3950,9 +3950,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none, (DescriptiveDeclKind)) ERROR(autodiff_attr_original_void_result,none, "cannot differentiate void function %0", (DeclName)) -ERROR(autodiff_attr_original_multiple_semantic_results,none, - "cannot differentiate functions with both an 'inout' parameter and a " - "result", ()) ERROR(autodiff_attr_result_not_differentiable,none, "can only differentiate functions with results that conform to " "'Differentiable', but %0 does not conform to 'Differentiable'", (Type)) diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 60a49b43e304f..057bfe05b9e18 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -176,39 +176,89 @@ void AnyFunctionType::getSubsetParameters( } } -void autodiff::getFunctionSemanticResultTypes( - AnyFunctionType *functionType, - SmallVectorImpl &result, - GenericEnvironment *genericEnv) { +void autodiff::getFunctionSemanticResults( + const AnyFunctionType *functionType, + const IndexSubset *parameterIndices, + SmallVectorImpl &resultTypes) { auto &ctx = functionType->getASTContext(); - // Remap type in `genericEnv`, if specified. - auto remap = [&](Type type) { - if (!genericEnv) - return type; - return genericEnv->mapTypeIntoContext(type); - }; - // Collect formal result type as a semantic result, unless it is // `Void`. auto formalResultType = functionType->getResult(); if (auto *resultFunctionType = - functionType->getResult()->getAs()) { + functionType->getResult()->getAs()) formalResultType = resultFunctionType->getResult(); + + unsigned resultIdx = 0; + if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) { + // Separate tuple elements into individual results. + if (formalResultType->is()) { + for (auto elt : formalResultType->castTo()->getElements()) { + resultTypes.emplace_back(elt.getType(), resultIdx++, + /*isInout*/ false, /*isWrt*/ false); + } + } else { + resultTypes.emplace_back(formalResultType, resultIdx++, + /*isInout*/ false, /*isWrt*/ false); + } } - if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) - result.push_back({remap(formalResultType), /*isInout*/ false}); - // Collect `inout` parameters as semantic results. - for (auto param : functionType->getParams()) - if (param.isInOut()) - result.push_back({remap(param.getPlainType()), /*isInout*/ true}); - if (auto *resultFunctionType = - functionType->getResult()->getAs()) { - for (auto param : resultFunctionType->getParams()) - if (param.isInOut()) - result.push_back({remap(param.getPlainType()), /*isInout*/ true}); + bool addNonWrts = resultTypes.empty(); + + // Collect wrt `inout` parameters as semantic results + // As an extention, collect all (including non-wrt) inouts as results for + // functions returning void. + auto collectSemanticResults = [&](const AnyFunctionType *functionType, + unsigned curryOffset = 0) { + for (auto paramAndIndex : enumerate(functionType->getParams())) { + if (!paramAndIndex.value().isInOut()) + continue; + + unsigned idx = paramAndIndex.index() + curryOffset; + assert(idx < parameterIndices->getCapacity() && + "invalid parameter index"); + bool isWrt = parameterIndices->contains(idx); + if (addNonWrts || isWrt) + resultTypes.emplace_back(paramAndIndex.value().getPlainType(), + resultIdx, /*isInout*/ true, isWrt); + resultIdx += 1; + } + }; + + if (auto *resultFnType = + functionType->getResult()->getAs()) { + // Here we assume that the input is a function type with curried `Self` + assert(functionType->getNumParams() == 1 && "unexpected function type"); + + collectSemanticResults(resultFnType); + collectSemanticResults(functionType, resultFnType->getNumParams()); + } else + collectSemanticResults(functionType); +} + +IndexSubset * +autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType, + const IndexSubset *parameterIndices) { + auto &ctx = functionType->getASTContext(); + + SmallVector semanticResults; + autodiff::getFunctionSemanticResults(functionType, parameterIndices, + semanticResults); + SmallVector resultIndices; + unsigned cap = 0; + for (const auto& result : semanticResults) { + resultIndices.push_back(result.index); + cap = std::max(cap, result.index + 1U); } + + return IndexSubset::get(ctx, cap, resultIndices); +} + +IndexSubset * +autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD, + const IndexSubset *parameterIndices) { + return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo(), + parameterIndices); } // TODO(TF-874): Simplify this helper. See TF-874 for WIP. @@ -395,9 +445,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const { case Kind::NoSemanticResults: OS << "has no semantic results ('Void' result)"; break; - case Kind::MultipleSemanticResults: - OS << "has multiple semantic results"; - break; case Kind::NoDifferentiabilityParameters: OS << "has no differentiability parameters"; break; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 8fcb0d8e87fcf..48f70c85f824d 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -5548,32 +5548,43 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( getSubsetParameters(parameterIndices, diffParams, /*reverseCurryLevels*/ !makeSelfParamFirst); - // Get the original semantic result type. + // Get the original non-inout semantic result types. SmallVector originalResults; - autodiff::getFunctionSemanticResultTypes(this, originalResults); + autodiff::getFunctionSemanticResults(this, parameterIndices, originalResults); // Error if no original semantic results. if (originalResults.empty()) return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Error if multiple original semantic results. - // TODO(TF-1250): Support functions with multiple semantic results. - if (originalResults.size() > 1) - return llvm::make_error( - this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults); - auto originalResult = originalResults.front(); - auto originalResultType = originalResult.type; - - // Get the original semantic result type's `TangentVector` associated type. - auto resultTan = - originalResultType->getAutoDiffTangentSpace(lookupConformance); - // Error if original semantic result has no tangent space. - if (!resultTan) { - return llvm::make_error( + + // Accumulate non-inout result tangent spaces. + SmallVector resultTanTypes, inoutTanTypes; + for (auto i : range(originalResults.size())) { + auto originalResult = originalResults[i]; + auto originalResultType = originalResult.type; + + // Voids currently have a defined tangent vector, so ignore them. + if (originalResultType->isVoid()) + continue; + + // Get the original semantic result type's `TangentVector` associated type. + // Error if a semantic result has no tangent space. + auto resultTan = + originalResultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, - std::make_pair(originalResultType, /*index*/ 0)); + std::make_pair(originalResultType, unsigned(originalResult.index))); + + if (!originalResult.isInout) + resultTanTypes.push_back(resultTan->getType()); + else if (originalResult.isInout && !originalResult.isWrtParam) + inoutTanTypes.push_back(resultTan->getType()); } - auto resultTanType = resultTan->getType(); + // Treat non-wrt inouts as semantic results for functions returning Void + if (resultTanTypes.empty()) + resultTanTypes = inoutTanTypes; + // Compute the result linear map function type. FunctionType *linearMapType; switch (kind) { @@ -5586,32 +5597,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // // Case 2: original function has a non-wrt `inout` parameter. // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, ...) -> T1.Tan` + // - Differential: `(T0.Tan, ...) -> T1.Tan` // // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` + // - Original: `(T0, inout T1, ...) -> Void` + // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` SmallVector differentialParams; - bool hasInoutDiffParameter = false; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance); // Error if parameter has no tangent space. - if (!paramTan) { + if (!paramTan) return llvm::make_error( this, DerivativeFunctionTypeError::Kind:: NonDifferentiableDifferentiabilityParameter, std::make_pair(paramType, i)); - } + differentialParams.push_back(AnyFunctionType::Param( paramTan->getType(), Identifier(), diffParam.getParameterFlags())); - if (diffParam.isInOut()) - hasInoutDiffParameter = true; } - auto differentialResult = - hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType; + Type differentialResult; + if (resultTanTypes.empty()) { + differentialResult = ctx.TheEmptyTupleType; + } else if (resultTanTypes.size() == 1) { + differentialResult = resultTanTypes.front(); + } else { + SmallVector differentialResults; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + differentialResults.push_back( + TupleTypeElt(resultTanType, Identifier())); + } + differentialResult = TupleType::get(differentialResults, ctx); + } + // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; linearMapType = @@ -5629,25 +5650,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, inout T1, ...) -> Void` // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` // - // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` + // Case 3: original function has wrt `inout` parameters. + // - Original: `(T0, inout T1, ...) -> R` + // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - bool hasInoutDiffParameter = false; + SmallVector inoutParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance); // Error if parameter has no tangent space. - if (!paramTan) { + if (!paramTan) return llvm::make_error( this, DerivativeFunctionTypeError::Kind:: NonDifferentiableDifferentiabilityParameter, std::make_pair(paramType, i)); - } + if (diffParam.isInOut()) { - hasInoutDiffParameter = true; + if (paramType->isVoid()) + continue; + inoutParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -5660,12 +5683,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } else { pullbackResult = TupleType::get(pullbackResults, ctx); } - auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter); - auto pullbackParam = - AnyFunctionType::Param(resultTanType, Identifier(), flags); + // First accumulate non-inout results as pullback parameters. + SmallVector pullbackParams; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + auto flags = ParameterTypeFlags().withInOut(false); + pullbackParams.push_back(AnyFunctionType::Param( + resultTanType, Identifier(), flags)); + } + // Then append inout parameters. + for (auto i : range(inoutParams.size())) { + auto inoutParam = inoutParams[i]; + auto inoutParamType = inoutParam.getPlainType(); + auto inoutParamTan = + inoutParamType->getAutoDiffTangentSpace(lookupConformance); + auto flags = ParameterTypeFlags().withInOut(true); + pullbackParams.push_back(AnyFunctionType::Param( + inoutParamTan->getType(), Identifier(), flags)); + } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; - linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info); + linearMapType = FunctionType::get(pullbackParams, pullbackResult, info); break; } } diff --git a/lib/IRGen/IRGenMangler.h b/lib/IRGen/IRGenMangler.h index 0ebf0992a85da..700716b8aaf65 100644 --- a/lib/IRGen/IRGenMangler.h +++ b/lib/IRGen/IRGenMangler.h @@ -57,9 +57,12 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(func, + derivativeId->getParameterIndices()); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tj"); @@ -86,9 +89,12 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(func, + derivativeId->getParameterIndices()); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tq"); diff --git a/lib/SIL/IR/SILDeclRef.cpp b/lib/SIL/IR/SILDeclRef.cpp index cef7d684530d6..2085a502274ca 100644 --- a/lib/SIL/IR/SILDeclRef.cpp +++ b/lib/SIL/IR/SILDeclRef.cpp @@ -1067,7 +1067,10 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { auto *silParameterIndices = autodiff::getLoweredParameterIndices( derivativeFunctionIdentifier->getParameterIndices(), getDecl()->getInterfaceType()->castTo()); - auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0}); + // FIXME: is this correct in the presence of curried types? + auto *resultIndices = autodiff::getFunctionSemanticResultIndices( + asAutoDiffOriginalFunction().getAbstractFunctionDecl(), + derivativeFunctionIdentifier->getParameterIndices()); AutoDiffConfig silConfig( silParameterIndices, resultIndices, derivativeFunctionIdentifier->getDerivativeGenericSignature()); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 78a0538a211fc..e7d21b34d1418 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -231,15 +231,15 @@ SILFunctionType::getDifferentiabilityParameterIndices() { IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { assert(isDifferentiable() && "Must be a differentiable function"); SmallVector resultIndices; + // Check formal results. for (auto resultAndIndex : enumerate(getResults())) if (resultAndIndex.value().getDifferentiability() != SILResultDifferentiability::NotDifferentiable) resultIndices.push_back(resultAndIndex.index()); + // Check `inout` parameters. for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) - // FIXME(TF-1305): The `getResults().empty()` condition is a hack. - // // Currently, an `inout` parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a @@ -254,10 +254,11 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { // parameters are not treated as differentiability results, unless the // original function has no formal results, in which case all `inout` // parameters are treated as differentiability results. - if (getResults().empty() || + if (resultIndices.empty() || inoutParamAndIndex.value().getDifferentiability() != - SILParameterDifferentiability::NotDifferentiable) + SILParameterDifferentiability::NotDifferentiable) resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + auto numSemanticResults = getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); @@ -371,24 +372,19 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy, /// `inout` parameters, in type order. static void getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, - IndexSubset *&inoutParameterIndices, SmallVectorImpl &originalResults) { - auto &C = functionType->getASTContext(); - SmallVector inoutParamIndices; // Collect original formal results. originalResults.append(functionType->getResults().begin(), functionType->getResults().end()); + // Collect original `inout` parameters. for (auto i : range(functionType->getNumParameters())) { auto param = functionType->getParameters()[i]; - if (!param.isIndirectInOut()) + if (!param.isIndirectMutating()) continue; - inoutParamIndices.push_back(i); - originalResults.push_back( - SILResultInfo(param.getInterfaceType(), ResultConvention::Indirect)); + if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) + originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect); } - inoutParameterIndices = - IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices); } static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig, @@ -566,10 +562,8 @@ static CanSILFunctionType getAutoDiffDifferentialType( SmallVector substReplacements; SmallVector substConformances; - IndexSubset *inoutParamIndices; SmallVector originalResults; - getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices, - originalResults); + getSemanticResults(originalFnTy, parameterIndices, originalResults); SmallVector diffParams; getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); @@ -603,7 +597,7 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialResults.push_back({resultTanType, resultConv}); continue; } - // Handle original `inout` parameter. + // Handle original `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); @@ -650,10 +644,8 @@ static CanSILFunctionType getAutoDiffPullbackType( SmallVector substReplacements; SmallVector substConformances; - IndexSubset *inoutParamIndices; SmallVector originalResults; - getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices, - originalResults); + getSemanticResults(originalFnTy, parameterIndices, originalResults); // Given a type, returns its formal SIL parameter info. auto getTangentParameterConventionForOriginalResult = @@ -745,7 +737,7 @@ static CanSILFunctionType getAutoDiffPullbackType( pullbackParams.push_back({resultTanType, paramConv}); continue; } - // Handle original `inout` parameter. + // Handle `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); diff --git a/lib/SIL/IR/SILSymbolVisitor.cpp b/lib/SIL/IR/SILSymbolVisitor.cpp index 27eb0c41e9be5..9a64e4c4b56ac 100644 --- a/lib/SIL/IR/SILSymbolVisitor.cpp +++ b/lib/SIL/IR/SILSymbolVisitor.cpp @@ -465,20 +465,29 @@ class SILSymbolVisitorImpl : public ASTVisitor { // Add derivative function symbols. for (const auto *differentiableAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = autodiff::getFunctionSemanticResultIndices( + AFD, + differentiableAttr->getParameterIndices()); addDerivativeConfiguration( differentiableAttr->getDifferentiabilityKind(), AFD, AutoDiffConfig(differentiableAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, differentiableAttr->getDerivativeGenericSignature())); + } + for (const auto *derivativeAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = autodiff::getFunctionSemanticResultIndices( + derivativeAttr->getOriginalFunction(AFD->getASTContext()), + derivativeAttr->getParameterIndices()); addDerivativeConfiguration( DifferentiabilityKind::Reverse, derivativeAttr->getOriginalFunction(AFD->getASTContext()), AutoDiffConfig(derivativeAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, AFD->getGenericSignature())); + } addRuntimeDiscoverableAttrGenerators(AFD); @@ -522,13 +531,17 @@ class SILSymbolVisitorImpl : public ASTVisitor { // Add derivative function symbols. for (const auto *differentiableAttr : - ASD->getAttrs().getAttributes()) + ASD->getAttrs().getAttributes()) { + // FIXME: handle other accessors + auto accessorDecl = ASD->getOpaqueAccessor(AccessorKind::Get); addDerivativeConfiguration( differentiableAttr->getDifferentiabilityKind(), - ASD->getOpaqueAccessor(AccessorKind::Get), + accessorDecl, AutoDiffConfig(differentiableAttr->getParameterIndices(), - IndexSubset::get(ASD->getASTContext(), 1, {0}), + autodiff::getFunctionSemanticResultIndices(accessorDecl, + differentiableAttr->getParameterIndices()), differentiableAttr->getDerivativeGenericSignature())); + } } void visitVarDecl(VarDecl *VD) { diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index c70a797ceadc9..52cd4650a6df8 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1255,11 +1255,13 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto *AFD = constant.getAbstractFunctionDecl(); auto emitWitnesses = [&](DeclAttributes &Attrs) { for (auto *diffAttr : Attrs.getAttributes()) { - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); assert((!F->getLoweredFunctionType()->getSubstGenericSignature() || diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(AFD, + diffAttr->getParameterIndices()); auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( AFD->getGenericSignature(), @@ -1288,7 +1290,9 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( origAFD->getGenericSignature(), AFD->getGenericSignature()); - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(origAFD, + derivAttr->getParameterIndices()); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, @@ -1311,6 +1315,7 @@ void SILGenModule::emitDifferentiabilityWitness( auto origSilFnType = originalFunction->getLoweredFunctionType(); auto *silParamIndices = autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType); + // NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has // parameters corresponding to captured variables. These parameters do not // appear in the type of `origFnType`. diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 78d3030fff342..c5c18859490a7 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -4885,10 +4885,12 @@ getWitnessFunctionRef(SILGenFunction &SGF, auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), witness.getDecl()->getInterfaceType()->castTo()); - auto *loweredResultIndices = IndexSubset::get( - SGF.getASTContext(), 1, {0}); // FIXME, set to all results + // FIXME: is this correct in the presence of curried types? + auto *resultIndices = autodiff::getFunctionSemanticResultIndices( + witness.getDecl()->getInterfaceType()->castTo(), + derivativeId->getParameterIndices()); auto diffFn = SGF.B.createDifferentiableFunction( - loc, loweredParamIndices, loweredResultIndices, originalFn); + loc, loweredParamIndices, resultIndices, originalFn); return SGF.B.createDifferentiableFunctionExtract( loc, NormalDifferentiableFunctionTypeComponent(derivativeId->getKind()), diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 0e131bf44e96d..8773f811efd9c 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -549,11 +549,14 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( SILGenFunctionBuilder builder(*this); auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction(); Mangle::ASTMangler mangler; + auto *resultIndices = autodiff::getFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl(), + derivativeId->getParameterIndices()); auto name = mangler.mangleAutoDiffDerivativeFunction( originalFnDeclRef.getAbstractFunctionDecl(), derivativeId->getKind(), AutoDiffConfig(derivativeId->getParameterIndices(), - IndexSubset::get(getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()), /*isVTableThunk*/ true); auto *thunk = builder.getOrCreateFunction( @@ -573,7 +576,12 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), derivativeFnDecl->getInterfaceType()->castTo()); - auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0}); + // FIXME: Do we need to lower the result indices? Likely yes. + auto *loweredResultIndices = + autodiff::getFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl(), + derivativeId->getParameterIndices() + ); auto diffFn = SGF.B.createDifferentiableFunction( loc, loweredParamIndices, loweredResultIndices, originalFn); auto derivativeFn = SGF.B.createDifferentiableFunctionExtract( diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index ec1bf4e79890e..836e586c219c4 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -232,7 +232,7 @@ void collectMinimalIndicesForFunctionCall( auto ¶m = paramAndIdx.value(); if (!param.isIndirectMutating()) continue; - unsigned idx = paramAndIdx.index(); + unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); auto inoutArg = ai->getArgument(idx); results.push_back(inoutArg); resultIndices.push_back(inoutParamResultIndex++); @@ -492,10 +492,6 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( SILModule &module, SILFunction *original, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices) { - // AST differentiability witnesses always have a single result. - if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0)) - return nullptr; - // Explicit differentiability witnesses only exist on SIL functions that come // from AST functions. auto *originalAFD = findAbstractFunctionDecl(original); diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 2a6b1cdec3833..52b505681ed43 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -474,6 +474,12 @@ getOrCreateSubsetParametersThunkForLinearMap( return mappedIndex; }; + auto toIndirectResultsIter = thunk->getIndirectResults().begin(); + auto useNextIndirectResult = [&]() { + assert(toIndirectResultsIter != thunk->getIndirectResults().end()); + arguments.push_back(*toIndirectResultsIter++); + }; + switch (kind) { // Differential arguments are: // - All indirect results, followed by: @@ -482,9 +488,29 @@ getOrCreateSubsetParametersThunkForLinearMap( // indices). // - Zeros (when parameter is not in desired indices). case AutoDiffDerivativeFunctionKind::JVP: { - // Forward all indirect results. - arguments.append(thunk->getIndirectResults().begin(), - thunk->getIndirectResults().end()); + unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults(); + // Forward desired indirect results + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numIndirectResults) + break; + + auto resultInfo = linearMapType->getResults()[idx]; + assert(idx < linearMapType->getNumResults()); + + // Forward result argument in case we do not need to thunk it away + if (desiredConfig.resultIndices->contains(idx)) { + useNextIndirectResult(); + continue; + } + + // Otherwise, allocate and use an uninitialized indirect result + auto *indirectResult = builder.createAllocStack( + loc, resultInfo.getSILStorageInterfaceType()); + localAllocations.push_back(indirectResult); + arguments.push_back(indirectResult); + } + assert(toIndirectResultsIter == thunk->getIndirectResults().end()); + auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin(); auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); }; // Iterate over actual indices. @@ -509,10 +535,6 @@ getOrCreateSubsetParametersThunkForLinearMap( // - Zeros (when parameter is not in desired indices). // - All actual arguments. case AutoDiffDerivativeFunctionKind::VJP: { - auto toIndirectResultsIter = thunk->getIndirectResults().begin(); - auto useNextIndirectResult = [&]() { - arguments.push_back(*toIndirectResultsIter++); - }; // Collect pullback arguments. unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { @@ -541,8 +563,17 @@ getOrCreateSubsetParametersThunkForLinearMap( arguments.push_back(indirectResult); } // Forward all actual non-indirect-result arguments. - arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(), - thunk->getArgumentsWithoutIndirectResults().end() - 1); + auto thunkArgs = thunk->getArgumentsWithoutIndirectResults(); + // Slice out the function to be called + thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1); + unsigned thunkArg = 0; + for (unsigned idx : *actualConfig.resultIndices) { + // Forward result argument in case we do not need to thunk it away + if (desiredConfig.resultIndices->contains(idx)) + arguments.push_back(thunkArgs[thunkArg++]); + else // otherwise, zero it out + buildZeroArgument(linearMapType->getParameters()[arguments.size()]); + } break; } } @@ -552,10 +583,33 @@ getOrCreateSubsetParametersThunkForLinearMap( auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments); // If differential thunk, deallocate local allocations and directly return - // `apply` result. + // `apply` result (if it is desired). if (kind == AutoDiffDerivativeFunctionKind::JVP) { + SmallVector differentialDirectResults; + extractAllElements(ai, builder, differentialDirectResults); + SmallVector allResults; + collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults); + unsigned numResults = thunk->getConventions().getNumDirectSILResults() + + thunk->getConventions().getNumDirectSILResults(); + SmallVector results; + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numResults) + break; + + auto result = allResults[idx]; + if (desiredConfig.isWrtResult(idx)) + results.push_back(result); + else { + if (result->getType().isAddress()) + builder.emitDestroyAddrAndFold(loc, result); + else + builder.emitDestroyValueOperation(loc, result); + } + } + cleanupValues(); - builder.createReturn(loc, ai); + auto result = joinElements(results, builder, loc); + builder.createReturn(loc, result); return {thunk, interfaceSubs}; } @@ -772,8 +826,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( /*withoutActuallyEscaping*/ false); } assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() == - 1); + origFnType->getNumIndirectMutatingParameters() > + 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { auto result = diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index bd20225556664..8219ed09dd444 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -750,15 +750,16 @@ static SILFunction *createEmptyVJP(ADContext &context, SILDifferentiabilityWitness *witness, IsSerialized_t isSerialized) { auto original = witness->getOriginalFunction(); + auto config = witness->getConfig(); LLVM_DEBUG({ auto &s = getADDebugStream(); - s << "Creating VJP:\n\t"; + s << "Creating VJP for " << original->getName() << ":\n\t"; s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; + s << "Config: " << config << "\n\t"; }); auto &module = context.getModule(); auto originalTy = original->getLoweredFunctionType(); - auto config = witness->getConfig(); // === Create an empty VJP. === Mangle::DifferentiationMangler mangler; @@ -794,15 +795,16 @@ static SILFunction *createEmptyJVP(ADContext &context, SILDifferentiabilityWitness *witness, IsSerialized_t isSerialized) { auto original = witness->getOriginalFunction(); + auto config = witness->getConfig(); LLVM_DEBUG({ auto &s = getADDebugStream(); - s << "Creating JVP:\n\t"; + s << "Creating JVP for " << original->getName() << ":\n\t"; s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; + s << "Config: " << config << "\n\t"; }); auto &module = context.getModule(); auto originalTy = original->getLoweredFunctionType(); - auto config = witness->getConfig(); Mangle::DifferentiationMangler mangler; auto jvpName = mangler.mangleDerivativeFunction( diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 9b21805409618..ccd3042d00f29 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5509,12 +5509,6 @@ bool resolveDifferentiableAttrDifferentiabilityParameters( original->getName()) .highlight(original->getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(original->getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); @@ -5666,7 +5660,19 @@ typecheckDifferentiableAttrforDecl(AbstractFunctionDecl *original, } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + SmallVector semanticResults; + + // Compute the derivative function type. + auto originalFnRemappedTy = original->getInterfaceType()->castTo(); + if (auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment()) + originalFnRemappedTy = + derivativeGenEnv->mapTypeIntoContext(originalFnRemappedTy) + ->castTo(); + + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(originalFnRemappedTy, + resolvedDiffParamIndices); + original->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -6124,12 +6130,6 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { originalAFD->getName()) .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); @@ -6219,7 +6219,9 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(Ctx, 1, {0}); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(originalAFD, + resolvedDiffParamIndices); originalAFD->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivative->getGenericSignature()}); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 6a8be9b91a7de..21856b6696e41 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -492,7 +492,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, witness->getAttrs().add(newAttr); success = true; // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(witnessAFD, + newAttr->getParameterIndices()); witnessAFD->addDerivativeFunctionConfiguration( {newAttr->getParameterIndices(), resultIndices, newAttr->getDerivativeGenericSignature()}); diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index 98e7cf39b4641..238cbcd0ac8e5 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -722,9 +722,10 @@ void ModuleFile::loadDerivativeFunctionConfigurations( } auto derivativeGenSig = derivativeGenSigOrError.get(); // NOTE(TF-1038): Result indices are currently unsupported in derivative - // registration attributes. In the meantime, always use `{0}` (wrt the - // first and only result). - auto resultIndices = IndexSubset::get(ctx, 1, {0}); + // registration attributes. In the meantime, always use all results. + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(originalAFD, + parameterIndices); results.insert({parameterIndices, resultIndices, derivativeGenSig}); } } diff --git a/test/AutoDiff/SILGen/inout_differentiability_witness.swift b/test/AutoDiff/SILGen/inout_differentiability_witness.swift new file mode 100644 index 0000000000000..e49b4e92a947d --- /dev/null +++ b/test/AutoDiff/SILGen/inout_differentiability_witness.swift @@ -0,0 +1,60 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s +import _Differentiation + +public struct NonDiffableStruct {} + +public struct DiffableStruct : Differentiable {} + +@differentiable(reverse) +func test1(x: Int, y: inout NonDiffableStruct, z: Float) -> Float { return 42.0 } + +@differentiable(reverse) +func test2(x: Int, y: inout DiffableStruct, z: Float) { } + +@differentiable(reverse) +func test3(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return (42.0, 42.0) } + +@differentiable(reverse, wrt: y) +func test4(x: Int, y: inout DiffableStruct, z: Float) -> Void { } + +@differentiable(reverse, wrt: z) +func test5(x: Int, y: inout DiffableStruct, z: Float) -> Void { } + +@differentiable(reverse, wrt: (y, z)) +func test6(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return (42.0, 42.0) } + +// CHECK-LABEL: differentiability witness for test1(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftF : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> Float { +// CHECK: jvp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: vjp: @$s31inout_differentiability_witness5test11x1y1zSfSi_AA17NonDiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout NonDiffableStruct, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: } + +// CHECK-LABEL: differentiability witness for test2(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0] @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJfUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test21x1y1zySi_AA14DiffableStructVzSftFTJrUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float +// CHECK: } + +// CHECK-LABEL: differentiability witness for test3(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 2] @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { +// CHECK: jvp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) +// CHECK: vjp: @$s31inout_differentiability_witness5test31x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) +// CHECK: } + +// CHECK-LABEL: differentiability witness for test4(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1] [results 0] @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJfUSUpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test41x1y1zySi_AA14DiffableStructVzSftFTJrUSUpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> () +// CHECK: } + +// CHECK-LABEL: differentiability witness for test5(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (Float) -> @out DiffableStruct.TangentVector +// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@in_guaranteed DiffableStruct.TangentVector) -> Float +// CHECK: } + +// CHECK-LABEL: differentiability witness for test6(x:y:z:) +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0 1 2] @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float) { +// CHECK: jvp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJfUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> (Float, Float)) +// CHECK: vjp: @$s31inout_differentiability_witness5test61x1y1zSf_SftSi_AA14DiffableStructVzSftFTJrUSSpSSSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> (Float, Float, @owned @callee_guaranteed (Float, Float, @inout DiffableStruct.TangentVector) -> Float) +// CHECK: } diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift index 59ec26ef9bd08..8942432a0a2a9 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift @@ -1,3 +1,25 @@ +import _Differentiation + public struct Struct { public func method(_ x: Float) -> Float { x } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +public func swap(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} + +@differentiable(reverse) +public func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +public func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = tmp + }) +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index 7947518708ad3..5485a5f9a68b7 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -11,3 +11,18 @@ extension Struct: Differentiable { (x, { $0 }) } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +func multiply_swap(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swap(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} + +@differentiable(reverse) +func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 310c1896bad6e..8d0f61b7e34e0 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -749,13 +749,19 @@ extension ProtocolRequirementDerivative { func multipleSemanticResults(_ x: inout Float) -> Float { return x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(x: inout Float) -> ( - value: Float, pullback: (Float) -> Float -) { - return (multipleSemanticResults(&x), { $0 }) + value: Float, pullback: (Float, inout Float) -> Void +) { fatalError() } + +func inoutNonDifferentiableResult(_ x: inout Float) -> Int { + return 5 } +// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} +@derivative(of: inoutNonDifferentiableResult) +func vjpInoutNonDifferentiableResult(x: inout Float) -> ( + value: Int, pullback: (inout Float) -> Void +) { fatalError() } struct InoutParameters: Differentiable { typealias TangentVector = DummyTangentVector @@ -888,17 +894,32 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {} extension InoutParameters { func multipleSemanticResults(_ x: inout Float) -> Float { x } - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} - @derivative(of: multipleSemanticResults) + @derivative(of: multipleSemanticResults, wrt: x) func vjpMultipleSemanticResults(_ x: inout Float) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float, inout Float) -> Void ) { fatalError() } func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} - @derivative(of: inoutVoid) + @derivative(of: inoutVoid, wrt: (x, void)) func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float) -> Float + ) { fatalError() } +} + +// Test tuple results. + +extension InoutParameters { + func tupleResults(_ x: Float) -> (Float, Float) { (x, x) } + @derivative(of: tupleResults, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> Float + ) { fatalError() } + + func tupleResultsInt(_ x: Float) -> (Int, Float) { (1, x) } + // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} + @derivative(of: tupleResultsInt, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Int, Float), pullback: (Float) -> Float ) { fatalError() } } diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 0cd6fa5b1bdb1..41b24ebbf7b62 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -528,7 +528,6 @@ func two9(x: Float, y: Float) -> Float { func inout1(x: Float, y: inout Float) -> Void { let _ = x + y } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func inout2(x: Float, y: inout Float) -> Float { let _ = x + y @@ -670,11 +669,9 @@ final class FinalClass: Differentiable { @differentiable(reverse, wrt: y) func inoutVoid(x: Float, y: inout Float) {} -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) func multipleSemanticResults(_ x: inout Float) -> Float { x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func swap(x: inout Float, y: inout Float) {} @@ -683,11 +680,23 @@ struct InoutParameters: Differentiable { mutating func move(by _: TangentVector) {} } +extension NonDiffableStruct { + // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'NonDiffableStruct' does not conform to 'Differentiable'}} + @differentiable(reverse) + static func nondiffResult(x: Int, y: inout NonDiffableStruct, z: Float) {} + + @differentiable(reverse) + static func diffResult(x: Int, y: inout NonDiffableStruct, z: Float) -> Float {} + + // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'NonDiffableStruct' does not conform to 'Differentiable'}} + @differentiable(reverse, wrt: (y, z)) + static func diffResult2(x: Int, y: inout NonDiffableStruct, z: Float) -> Float {} +} + extension InoutParameters { @differentiable(reverse) static func staticMethod(_ lhs: inout Self, rhs: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {} } @@ -696,11 +705,32 @@ extension InoutParameters { @differentiable(reverse) mutating func mutatingMethod(_ other: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) mutating func mutatingMethod(_ other: Self) -> Self {} } +// Test tuple results. + +extension InoutParameters { + @differentiable(reverse) + static func tupleResults(_ x: Self) -> (Self, Self) {} + + // Int does not conform to Differentiable + // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} + @differentiable(reverse) + static func tupleResultsInt(_ x: Self) -> (Int, Self) {} + + // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} + @differentiable(reverse) + static func tupleResultsInt2(_ x: Self) -> (Self, Int) {} + + @differentiable(reverse) + static func tupleResultsFloat(_ x: Self) -> (Float, Self) {} + + @differentiable(reverse) + static func tupleResultsFloat2(_ x: Self) -> (Self, Float) {} +} + // Test accessors: `set`, `_read`, `_modify`. struct Accessors: Differentiable { diff --git a/test/AutoDiff/Serialization/derivative_attr.swift b/test/AutoDiff/Serialization/derivative_attr.swift index 91677baa80e25..c41c0a36d1a50 100644 --- a/test/AutoDiff/Serialization/derivative_attr.swift +++ b/test/AutoDiff/Serialization/derivative_attr.swift @@ -37,6 +37,26 @@ func derivativeTop2( (y, { (dx, dy) in dy }) } +// Test top-level inout functions. + +func topInout1(_ x: inout S) {} + +// CHECK: @derivative(of: topInout1, wrt: x) +@derivative(of: topInout1) +func derivativeTopInout1(_ x: inout S) -> (value: Void, pullback: (inout S) -> Void) { + fatalError() +} + +func topInout2(_ x: inout S) -> S { + x +} + +// CHECK: @derivative(of: topInout2, wrt: x) +@derivative(of: topInout2) +func derivativeTopInout2(_ x: inout S) -> (value: S, pullback: (S, inout S) -> Void) { + fatalError() +} + // Test instance methods. extension S { diff --git a/test/AutoDiff/Serialization/differentiable_attr.swift b/test/AutoDiff/Serialization/differentiable_attr.swift index b8c83362bd813..e09f7541caf90 100644 --- a/test/AutoDiff/Serialization/differentiable_attr.swift +++ b/test/AutoDiff/Serialization/differentiable_attr.swift @@ -43,6 +43,29 @@ func testWrtClause(x: Float, y: Float) -> Float { return x } +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInout(x: inout Float) +@differentiable(reverse) +func testInout(x: inout Float) { + x = x * 2.0 +} + +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInoutResult(x: inout Float) -> Float +@differentiable(reverse) +func testInoutResult(x: inout Float) -> Float { + x = x * 2.0 + return x +} + +// CHECK: @differentiable(reverse, wrt: (x, y)) +// CHECK-NEXT: func testMultipleInout(x: inout Float, y: inout Float) +@differentiable(reverse) +func testMultipleInout(x: inout Float, y: inout Float) { + x = x * y + y = x +} + struct InstanceMethod : Differentiable { // CHECK: @differentiable(reverse, wrt: (self, y)) // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float diff --git a/test/AutoDiff/Serialization/differentiable_function.swift b/test/AutoDiff/Serialization/differentiable_function.swift index 316a0a6eca40d..e31d874bb8920 100644 --- a/test/AutoDiff/Serialization/differentiable_function.swift +++ b/test/AutoDiff/Serialization/differentiable_function.swift @@ -15,3 +15,15 @@ func b(_ f: @differentiable(_linear) (Float) -> Float) {} func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) {} // CHECK: func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) + +func d(_ f: @differentiable(reverse) (inout Float) -> ()) {} +// CHECK: func d(_ f: @differentiable(reverse) (inout Float) -> ()) + +func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) {} +// CHECK: func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) + +func f(_ f: @differentiable(reverse) (inout Float) -> Float) {} +// CHECK: func f(_ f: @differentiable(reverse) (inout Float) -> Float) + +func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) {} +// CHECK: func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index e76bdb610b717..46374e6b2cb38 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -121,6 +121,69 @@ SimpleMathTests.test("MultipleResults") { expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapAndReturnProduct)) } +// Test function with multiple `inout` parameters and a custom pullback. +@differentiable(reverse) +func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = tmp + }) +} + +SimpleMathTests.test("MultipleResultsWithCustomPullback") { + func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom)) +} + +// Test functions returning tuples. +@differentiable(reverse) +func swapTuple(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} + +@differentiable(reverse) +func swapTupleCustom(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} +@derivative(of: swapTupleCustom) +func vjpSwapTupleCustom(_ x: Float, _ y: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> (Float, Float) +) { + return (swapTupleCustom(x, y), {v1, v2 in + return (v2, v1) + }) +} + +SimpleMathTests.test("ReturningTuples") { + func multiply_swapTuple(_ x: Float, _ y: Float) -> Float { + let result = swapTuple(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTuple)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTuple)) + + func multiply_swapTupleCustom(_ x: Float, _ y: Float) -> Float { + let result = swapTupleCustom(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTupleCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTupleCustom)) +} + SimpleMathTests.test("CaptureLocal") { let z: Float = 10 func foo(_ x: Float) -> Float {