Skip to content

[AutoDiff] Supporting differentiable functions with multiple semantic results #66873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

namespace swift {

class AbstractFunctionDecl;
class AnyFunctionType;
class SourceFile;
class SILFunctionType;
Expand Down Expand Up @@ -247,7 +248,12 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
/// an `inout` parameter type. Used in derivative function type calculation.
struct AutoDiffSemanticFunctionResultType {
Type type;
bool isInout;
unsigned index : 30;
bool isInout : 1;
bool isWrtParam : 1;

AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
};

/// Key for caching SIL derivative function types.
Expand Down Expand Up @@ -398,9 +404,6 @@ class DerivativeFunctionTypeError
enum class Kind {
/// Original function type has no semantic results.
NoSemanticResults,
/// Original function type has multiple semantic results.
// TODO(TF-1250): Support function types with multiple semantic results.
MultipleSemanticResults,
/// Differentiability parmeter indices are empty.
NoDifferentiabilityParameters,
/// A differentiability parameter does not conform to `Differentiable`.
Expand Down Expand Up @@ -429,7 +432,6 @@ class DerivativeFunctionTypeError
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
: functionType(functionType), kind(kind), value(Value()) {
assert(kind == Kind::NoSemanticResults ||
kind == Kind::MultipleSemanticResults ||
kind == Kind::NoDifferentiabilityParameters);
};

Expand Down Expand Up @@ -572,15 +574,22 @@ namespace autodiff {
/// `inout` parameter types.
///
/// The function type may have at most two parameter lists.
///
/// Remaps the original semantic result using `genericEnv`, if specified.
void getFunctionSemanticResultTypes(
AnyFunctionType *functionType,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv = nullptr);
void getFunctionSemanticResults(
const AnyFunctionType *functionType,
const IndexSubset *parameterIndices,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes);

/// Returns the indices of semantic results for a given function.
IndexSubset *getFunctionSemanticResultIndices(
const AnyFunctionType *functionType,
const IndexSubset *parameterIndices);

IndexSubset *getFunctionSemanticResultIndices(
const AbstractFunctionDecl *AFD,
const IndexSubset *parameterIndices);

/// Returns the lowered SIL parameter indices for the given AST parameter
/// indices and `AnyfunctionType`.
/// indices and `AnyFunctionType`.
///
/// Notable lowering-related changes:
/// - AST tuple parameter types are exploded when lowered to SIL.
Expand Down
3 changes: 0 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3950,9 +3950,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
(DescriptiveDeclKind))
ERROR(autodiff_attr_original_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot differentiate functions with both an 'inout' parameter and a "
"result", ())
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
Expand Down
99 changes: 73 additions & 26 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,39 +176,89 @@ void AnyFunctionType::getSubsetParameters(
}
}

void autodiff::getFunctionSemanticResultTypes(
AnyFunctionType *functionType,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv) {
void autodiff::getFunctionSemanticResults(
const AnyFunctionType *functionType,
const IndexSubset *parameterIndices,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
auto &ctx = functionType->getASTContext();

// Remap type in `genericEnv`, if specified.
auto remap = [&](Type type) {
if (!genericEnv)
return type;
return genericEnv->mapTypeIntoContext(type);
};

// Collect formal result type as a semantic result, unless it is
// `Void`.
auto formalResultType = functionType->getResult();
if (auto *resultFunctionType =
functionType->getResult()->getAs<AnyFunctionType>()) {
functionType->getResult()->getAs<AnyFunctionType>())
formalResultType = resultFunctionType->getResult();

unsigned resultIdx = 0;
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
// Separate tuple elements into individual results.
if (formalResultType->is<TupleType>()) {
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
resultTypes.emplace_back(elt.getType(), resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
}
} else {
resultTypes.emplace_back(formalResultType, resultIdx++,
/*isInout*/ false, /*isWrt*/ false);
}
}
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
result.push_back({remap(formalResultType), /*isInout*/ false});

// Collect `inout` parameters as semantic results.
for (auto param : functionType->getParams())
if (param.isInOut())
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
if (auto *resultFunctionType =
functionType->getResult()->getAs<AnyFunctionType>()) {
for (auto param : resultFunctionType->getParams())
if (param.isInOut())
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
bool addNonWrts = resultTypes.empty();

// Collect wrt `inout` parameters as semantic results
// As an extention, collect all (including non-wrt) inouts as results for
// functions returning void.
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
unsigned curryOffset = 0) {
for (auto paramAndIndex : enumerate(functionType->getParams())) {
if (!paramAndIndex.value().isInOut())
continue;

unsigned idx = paramAndIndex.index() + curryOffset;
assert(idx < parameterIndices->getCapacity() &&
"invalid parameter index");
bool isWrt = parameterIndices->contains(idx);
if (addNonWrts || isWrt)
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
resultIdx, /*isInout*/ true, isWrt);
resultIdx += 1;
}
};

if (auto *resultFnType =
functionType->getResult()->getAs<AnyFunctionType>()) {
// Here we assume that the input is a function type with curried `Self`
assert(functionType->getNumParams() == 1 && "unexpected function type");

collectSemanticResults(resultFnType);
collectSemanticResults(functionType, resultFnType->getNumParams());
} else
collectSemanticResults(functionType);
}

IndexSubset *
autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType,
const IndexSubset *parameterIndices) {
auto &ctx = functionType->getASTContext();

SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
autodiff::getFunctionSemanticResults(functionType, parameterIndices,
semanticResults);
SmallVector<unsigned> resultIndices;
unsigned cap = 0;
for (const auto& result : semanticResults) {
resultIndices.push_back(result.index);
cap = std::max(cap, result.index + 1U);
}

return IndexSubset::get(ctx, cap, resultIndices);
}

IndexSubset *
autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD,
const IndexSubset *parameterIndices) {
return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo<AnyFunctionType>(),
parameterIndices);
}

// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
Expand Down Expand Up @@ -395,9 +445,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
case Kind::NoSemanticResults:
OS << "has no semantic results ('Void' result)";
break;
case Kind::MultipleSemanticResults:
OS << "has multiple semantic results";
break;
case Kind::NoDifferentiabilityParameters:
OS << "has no differentiability parameters";
break;
Expand Down
116 changes: 77 additions & 39 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5548,32 +5548,43 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
getSubsetParameters(parameterIndices, diffParams,
/*reverseCurryLevels*/ !makeSelfParamFirst);

// Get the original semantic result type.
// Get the original non-inout semantic result types.
SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
autodiff::getFunctionSemanticResultTypes(this, originalResults);
autodiff::getFunctionSemanticResults(this, parameterIndices, originalResults);
// Error if no original semantic results.
if (originalResults.empty())
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
// Error if multiple original semantic results.
// TODO(TF-1250): Support functions with multiple semantic results.
if (originalResults.size() > 1)
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
auto originalResult = originalResults.front();
auto originalResultType = originalResult.type;

// Get the original semantic result type's `TangentVector` associated type.
auto resultTan =
originalResultType->getAutoDiffTangentSpace(lookupConformance);
// Error if original semantic result has no tangent space.
if (!resultTan) {
return llvm::make_error<DerivativeFunctionTypeError>(

// Accumulate non-inout result tangent spaces.
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
for (auto i : range(originalResults.size())) {
auto originalResult = originalResults[i];
auto originalResultType = originalResult.type;

// Voids currently have a defined tangent vector, so ignore them.
if (originalResultType->isVoid())
continue;

// Get the original semantic result type's `TangentVector` associated type.
// Error if a semantic result has no tangent space.
auto resultTan =
originalResultType->getAutoDiffTangentSpace(lookupConformance);
if (!resultTan)
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
std::make_pair(originalResultType, /*index*/ 0));
std::make_pair(originalResultType, unsigned(originalResult.index)));

if (!originalResult.isInout)
resultTanTypes.push_back(resultTan->getType());
else if (originalResult.isInout && !originalResult.isWrtParam)
inoutTanTypes.push_back(resultTan->getType());
}
auto resultTanType = resultTan->getType();

// Treat non-wrt inouts as semantic results for functions returning Void
if (resultTanTypes.empty())
resultTanTypes = inoutTanTypes;

// Compute the result linear map function type.
FunctionType *linearMapType;
switch (kind) {
Expand All @@ -5586,32 +5597,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
//
// Case 2: original function has a non-wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, ...) -> T1.Tan`
// - Differential: `(T0.Tan, ...) -> T1.Tan`
//
// Case 3: original function has a wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
// - Original: `(T0, inout T1, ...) -> Void`
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
SmallVector<AnyFunctionType::Param, 4> differentialParams;
bool hasInoutDiffParameter = false;
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
// Error if parameter has no tangent space.
if (!paramTan) {
if (!paramTan)
return llvm::make_error<DerivativeFunctionTypeError>(
this,
DerivativeFunctionTypeError::Kind::
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));
}

differentialParams.push_back(AnyFunctionType::Param(
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
if (diffParam.isInOut())
hasInoutDiffParameter = true;
}
auto differentialResult =
hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType;
Type differentialResult;
if (resultTanTypes.empty()) {
differentialResult = ctx.TheEmptyTupleType;
} else if (resultTanTypes.size() == 1) {
differentialResult = resultTanTypes.front();
} else {
SmallVector<TupleTypeElt, 2> differentialResults;
for (auto i : range(resultTanTypes.size())) {
auto resultTanType = resultTanTypes[i];
differentialResults.push_back(
TupleTypeElt(resultTanType, Identifier()));
}
differentialResult = TupleType::get(differentialResults, ctx);
}

// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;
linearMapType =
Expand All @@ -5629,25 +5650,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
// - Original: `(T0, inout T1, ...) -> Void`
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
//
// Case 3: original function has a wrt `inout` parameter.
// - Original: `(T0, inout T1, ...) -> Void`
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
// Case 3: original function has wrt `inout` parameters.
// - Original: `(T0, inout T1, ...) -> R`
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
SmallVector<TupleTypeElt, 4> pullbackResults;
bool hasInoutDiffParameter = false;
SmallVector<AnyFunctionType::Param, 2> inoutParams;
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
// Error if parameter has no tangent space.
if (!paramTan) {
if (!paramTan)
return llvm::make_error<DerivativeFunctionTypeError>(
this,
DerivativeFunctionTypeError::Kind::
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));
}

if (diffParam.isInOut()) {
hasInoutDiffParameter = true;
if (paramType->isVoid())
continue;
inoutParams.push_back(diffParam);
continue;
}
pullbackResults.emplace_back(paramTan->getType());
Expand All @@ -5660,12 +5683,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
} else {
pullbackResult = TupleType::get(pullbackResults, ctx);
}
auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter);
auto pullbackParam =
AnyFunctionType::Param(resultTanType, Identifier(), flags);
// First accumulate non-inout results as pullback parameters.
SmallVector<FunctionType::Param, 2> pullbackParams;
for (auto i : range(resultTanTypes.size())) {
auto resultTanType = resultTanTypes[i];
auto flags = ParameterTypeFlags().withInOut(false);
pullbackParams.push_back(AnyFunctionType::Param(
resultTanType, Identifier(), flags));
}
// Then append inout parameters.
for (auto i : range(inoutParams.size())) {
auto inoutParam = inoutParams[i];
auto inoutParamType = inoutParam.getPlainType();
auto inoutParamTan =
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
auto flags = ParameterTypeFlags().withInOut(true);
pullbackParams.push_back(AnyFunctionType::Param(
inoutParamTan->getType(), Identifier(), flags));
}
// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;
linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info);
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
break;
}
}
Expand Down
Loading