Skip to content

[AutoDiff] inout parameter differentiation mega-patch. #29956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ struct AutoDiffConfig {
SWIFT_DEBUG_DUMP;
};

/// A semantic function result type: either a formal function result type or
/// an `inout` parameter. Used in derivative function type calculation.
struct SemanticFunctionResultType {
Type type;
bool isInout;
};

class ParsedAutoDiffParameter {
public:
enum class Kind { Named, Ordered, Self };
Expand Down Expand Up @@ -240,11 +247,20 @@ using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
/// Automatic differentiation utility namespace.
namespace autodiff {

/// Appends the subset's parameter's types to `results`, in the order in
/// which they appear in the function type.
void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
SmallVectorImpl<Type> &results,
bool reverseCurryLevels = false);
/// Given an original/derivative function type and the original formal result
/// type, return the original semantic result type: either the original formal
/// result type or an `inout` parameter.
///
/// The original/derivative function type may have at most two parameter lists.
///
/// Sets `hasMultipleOriginalSemanticResults` to true iff there are multiple
/// semantic results.
///
/// Remap the original semantic result using `derivativeGenEnv`, if specified.
SemanticFunctionResultType getOriginalFunctionSemanticResultType(
AnyFunctionType *functionType, Type originalFormalResultType,
bool &hasMultipleOriginalSemanticResults,
GenericEnvironment *derivativeGenEnv = nullptr);

/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
Expand Down
10 changes: 8 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2974,8 +2974,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
(DeclName, DeclName))

// @differentiable
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
"attribute for transpose registration instead", ())
Expand All @@ -3000,6 +2998,9 @@ ERROR(differentiable_attr_invalid_access,none,
ERROR(differentiable_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
ERROR(differentiable_attr_multiple_original_results,none,
"cannot yet differentiate functions with more than one semantic result "
"(formal function result or 'inout' parameter)", ())
ERROR(differentiable_attr_protocol_req_where_clause,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'where' clause", ())
Expand Down Expand Up @@ -3093,6 +3094,11 @@ ERROR(autodiff_attr_original_decl_none_valid_found,none,
"could not find function %0 with expected type %1", (DeclNameRef, Type))
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
"%0 is not defined in the current type context", (DeclNameRef))
ERROR(autodiff_attr_original_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot yet differentiate functions with more than one semantic result "
"(formal function result or 'inout' parameter)", ())

// differentiation `wrt` parameters clause
ERROR(diff_function_no_parameters,none,
Expand Down
4 changes: 3 additions & 1 deletion include/swift/AST/IndexSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode {
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
ArrayRef<unsigned> indices) {
SmallBitVector indicesBitVec(capacity, false);
for (auto index : indices)
for (auto index : indices) {
assert(index < capacity);
indicesBitVec.set(index);
}
return IndexSubset::get(ctx, indicesBitVec);
}

Expand Down
98 changes: 94 additions & 4 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3213,15 +3213,25 @@ class AnyFunctionType : public TypeBase {
return getExtInfo().getRepresentation();
}

/// Appends the parameters indicated by `parameterIndices` to `results`.
///
/// For curried function types: if `reverseCurryLevels` is true, append
/// the `self` parameter last instead of first.
///
/// TODO(TF-874): Simplify logic and remove the `reverseCurryLevels` flag.
void getSubsetParameters(IndexSubset *parameterIndices,
SmallVectorImpl<AnyFunctionType::Param> &results,
bool reverseCurryLevels = false);

/// Returns the derivative function type for the given parameter indices,
/// result index, derivative function kind, derivative function generic
/// signature (optional), and other auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - The result corresponding to the result index must conform to
/// `Differentiable`.
/// - There is one semantic function result type: either the formal original
/// result or an `inout` parameter. It must conform to `Differentiable`.
///
/// Typing rules, given:
/// - Original function type. Three cases:
Expand Down Expand Up @@ -3267,6 +3277,11 @@ class AnyFunctionType : public TypeBase {
/// original result | deriv. wrt result | deriv. wrt params
/// \endverbatim
///
/// The original type may have `inout` parameters. If so, the
/// differential/pullback typing rules are more nuanced: see documentation for
/// `getAutoDiffReturnedLinearMapFunctionType` for details. Semantically,
/// `inout` parameters behave as both parameters and results.
///
/// By default, if the original type has a `self` parameter list and parameter
/// indices include `self`, the computed derivative function type will return
/// a linear map taking/returning self's tangent *last* instead of first, for
Expand All @@ -3277,14 +3292,58 @@ class AnyFunctionType : public TypeBase {
/// derivative function types, e.g. when type-checking `@differentiable` and
/// `@derivative` attributes.
AnyFunctionType *getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind,
IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature derivativeGenericSignature = GenericSignature(),
bool makeSelfParamFirst = false);

/// Returns the linear map function type returned by the derivative function
/// type for the given parameter indices, linear map function kind, and other
/// auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - There is one semantic function result type: either the formal original
/// result or an `inout` parameter. It must conform to `Differentiable`.
///
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
/// "wrt" result derivative.
///
/// - Case 1: no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> R`
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
///
/// Pullback typing rules: takes a "wrt" result derivative and returns "wrt"
/// parameter derivatives.
///
/// - Case 1: original function has no `inout` parameters.
/// - Original: `(T0, T1, ...) -> R`
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
/// - Case 2: original function has a non-wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> Void`
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
/// - Case 3: original function has a wrt `inout` parameter.
/// - Original: `(T0, inout T1, ...) -> R`
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
///
/// If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
/// first. `makeSelfParamFirst` should be true when working with user-facing
/// derivative function types, e.g. when type-checking `@differentiable` and
/// `@derivative` attributes.
AnyFunctionType *getAutoDiffReturnedLinearMapFunctionType(
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false);

// SWIFT_ENABLE_TENSORFLOW
AnyFunctionType *getWithoutDifferentiability() const;
// SWIFT_ENABLE_TENSORFLOW END

/// True if the parameter declaration it is attached to is guaranteed
/// to not persist the closure for longer than the duration of the call.
Expand Down Expand Up @@ -4419,6 +4478,28 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
return getParameters().back();
}

struct IndirectMutatingParameterFilter {
bool operator()(SILParameterInfo param) const {
return param.isIndirectMutating();
}
};
using IndirectMutatingParameterIter =
llvm::filter_iterator<const SILParameterInfo *,
IndirectMutatingParameterFilter>;
using IndirectMutatingParameterRange =
iterator_range<IndirectMutatingParameterIter>;

/// A range of SILParameterInfo for all indirect mutating parameters.
IndirectMutatingParameterRange getIndirectMutatingParameters() const {
return llvm::make_filter_range(getParameters(),
IndirectMutatingParameterFilter());
}

/// Returns the number of indirect mutating parameters.
unsigned getNumIndirectMutatingParameters() const {
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
}

/// Get the generic signature used to apply the substitutions of a substituted function type
CanGenericSignature getSubstGenericSignature() const {
return GenericSigAndIsImplied.getPointer();
Expand Down Expand Up @@ -4487,18 +4568,27 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
/// - Returns original results, followed by a differential function, which
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
///
/// \verbatim
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
/// original results | derivatives wrt params | derivative wrt result
/// \endverbatim
///
/// VJP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a pullback function, which
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
///
/// \verbatim
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
/// original results | derivative wrt result | derivatives wrt params
/// \endverbatim
///
/// The original type may have `inout` parameters. If so, the
/// differential/pullback typing rules are more nuanced: see documentation for
/// `getAutoDiffReturnedLinearMapFunctionType` for details. Semantically,
/// `inout` parameters behave as both parameters and results.
///
/// A "constrained derivative generic signature" is computed from
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
Expand Down
60 changes: 52 additions & 8 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "swift/AST/ASTContext.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/Module.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"

Expand Down Expand Up @@ -50,13 +52,11 @@ static unsigned countNumFlattenedElementTypes(Type type) {
}

// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
// See TF-874 for WIP.
void autodiff::getSubsetParameterTypes(IndexSubset *subset,
AnyFunctionType *type,
SmallVectorImpl<Type> &results,
bool reverseCurryLevels) {
void AnyFunctionType::getSubsetParameters(
IndexSubset *parameterIndices,
SmallVectorImpl<AnyFunctionType::Param> &results, bool reverseCurryLevels) {
SmallVector<AnyFunctionType *, 2> curryLevels;
unwrapCurryLevels(type, curryLevels);
unwrapCurryLevels(this, curryLevels);

SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
unsigned currentOffset = 0;
Expand All @@ -77,11 +77,55 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
unsigned parameterIndexOffset =
curryLevelParameterIndexOffsets[curryLevelIndex];
for (unsigned paramIndex : range(curryLevel->getNumParams()))
if (subset->contains(parameterIndexOffset + paramIndex))
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
results.push_back(curryLevel->getParams()[paramIndex]);
}
}

SemanticFunctionResultType autodiff::getOriginalFunctionSemanticResultType(
AnyFunctionType *functionType, Type originalFormalResultType,
bool &hasMultipleOriginalSemanticResults,
GenericEnvironment *derivativeGenEnv) {
auto &ctx = functionType->getASTContext();
hasMultipleOriginalSemanticResults = false;

// Initialize original semantic result type as the original formal result
// type, unless it is `Void`.
Type originalResultType;
bool isOriginalResultInout = false;
if (!originalFormalResultType->isEqual(ctx.TheEmptyTupleType))
originalResultType = originalFormalResultType;

auto setOriginalSemanticResultType = [&](Type type) {
// If original semantic result type has already been set, unset it and set
// `hasMultipleOriginalSemanticResults` to true.
if (originalResultType) {
hasMultipleOriginalSemanticResults = true;
originalResultType = Type();
return;
}
originalResultType = type;
isOriginalResultInout = true;
};

// Check for `inout` parameters.
for (auto param : functionType->getParams())
if (param.isInOut())
setOriginalSemanticResultType(param.getPlainType());
if (auto *resultFunctionType =
functionType->getResult()->getAs<AnyFunctionType>()) {
for (auto param : resultFunctionType->getParams())
if (param.isInOut())
setOriginalSemanticResultType(param.getPlainType());
}

// Map original semantic result type into derivative generic environment.
if (originalResultType && derivativeGenEnv)
originalResultType =
derivativeGenEnv->mapTypeIntoContext(originalResultType);
return {originalResultType, isOriginalResultInout};
}

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
BuiltinFunctionBuilder::LambdaGenerator resultGen{
[=, &Context](BuiltinFunctionBuilder &builder) -> Type {
auto derivativeFnTy = diffFnType->getAutoDiffDerivativeFunctionType(
paramIndices, /*resultIndex*/ 0, kind,
paramIndices, kind,
LookUpConformanceInModule(Context.TheBuiltinModule));
return derivativeFnTy->getResult();
}};
Expand Down
Loading