[AutoDiff] [TF-1288] Supporting differentiable functions with multiple semantic results #38781
Closed
Changes from all commits (12 commits)
8b6704f  Initial easing of checks for differentiability of inouts with results. (BradLarson)
2c7672a  Initial attempt at allowing for multiple results in getAutoDiffDeriva… (BradLarson)
47fff24  Attempting to detect non-wrt inout parameters, all inouts are now res… (BradLarson)
2d49ad5  Iterate over all result indices in getAutoDiffPullbackType() and getA… (BradLarson)
98a8337  Returning all resultIndices from emitDifferentiabilityWitnessesForFun… (BradLarson)
6ce7df5  Adding test for cross-module registration of functions with multiple … (BradLarson)
2420dbc  Loosening checks that assume only one result. Adjusting tests. (BradLarson)
f5d4885  Reworking logic for non-wrt inout parameters. Replacing single result… (BradLarson)
ff8dc58  Consolidated repeated result index generation into a central function. (BradLarson)
4997035  Converting a last few areas to use multiple result indices. (BradLarson)
ffe1e52  Added tuple result tests, extracting tuple elements as semantic resul… (BradLarson)
17e6987  Adding @asl's fix for subset parameters thunks involving functions wi… (BradLarson)
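For orientation, the kind of function this PR targets has more than one semantic result, for example a formal result plus a wrt `inout` parameter. The sketch below is hypothetical (the name and body are not taken from this PR's tests) and assumes a toolchain where multiple semantic results are accepted:

```swift
import _Differentiation

// Hypothetical example: the formal `Float` result and the wrt `inout`
// parameter are both semantic results of the same function.
@differentiable(reverse)
func scaleAndAccumulate(_ x: Float, _ total: inout Float) -> Float {
  total += x
  return x * 2
}
```

On toolchains without that support, a declaration like this is rejected with the multiple-semantic-results diagnostic that the diff below removes.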
@@ -6432,31 +6432,86 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
  getSubsetParameters(parameterIndices, diffParams,
                      /*reverseCurryLevels*/ !makeSelfParamFirst);

  // Get the original semantic result type.
  // Get the original non-inout semantic result types.
  SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
  autodiff::getFunctionSemanticResultTypes(this, originalResults);
  // Error if no original semantic results.
  if (originalResults.empty())
    return llvm::make_error<DerivativeFunctionTypeError>(
        this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
  // Error if multiple original semantic results.
  // TODO(TF-1250): Support functions with multiple semantic results.
  if (originalResults.size() > 1)
    return llvm::make_error<DerivativeFunctionTypeError>(
        this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
  auto originalResult = originalResults.front();
  auto originalResultType = originalResult.type;

  // Get the original semantic result type's `TangentVector` associated type.
  auto resultTan =
      originalResultType->getAutoDiffTangentSpace(lookupConformance);
  // Error if original semantic result has no tangent space.
  if (!resultTan) {
  // Accumulate non-inout result tangent spaces.
  SmallVector<Type, 1> resultTanTypes;
  bool hasInoutResult = false;
  for (auto i : range(originalResults.size())) {
    auto originalResult = originalResults[i];
    auto originalResultType = originalResult.type;
    // Voids currently have a defined tangent vector, so ignore them.
    if (originalResultType->isVoid())
      continue;
    if (originalResult.isInout) {
      hasInoutResult = true;
      continue;
    }
    // Get the original semantic result type's `TangentVector` associated type.
    auto resultTan =
        originalResultType->getAutoDiffTangentSpace(lookupConformance);
    if (!resultTan)
      continue;
    auto resultTanType = resultTan->getType();
    resultTanTypes.push_back(resultTanType);
  }
  // Append non-wrt inout result tangent spaces.
  // This uses the logic from getSubsetParameters(), only operating over all
  // parameter indices and looking for non-wrt indices.
  SmallVector<AnyFunctionType *, 2> curryLevels;
  // An inlined version of unwrapCurryLevels().
Review comment (inline): Note: Without going too deep into implementation details: we only need to handle one potential "curry level" for method declarations – no need for loops or confusing ad-hoc terminology.
  AnyFunctionType *fnTy = this;
  while (fnTy != nullptr) {
    curryLevels.push_back(fnTy);
    fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
  }

  SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
  unsigned currentOffset = 0;
  for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
    curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
    currentOffset += curryLevels[curryLevelIndex]->getNumParams();
  }

  if (!makeSelfParamFirst) {
    std::reverse(curryLevels.begin(), curryLevels.end());
    std::reverse(curryLevelParameterIndexOffsets.begin(),
                 curryLevelParameterIndexOffsets.end());
  }

  for (unsigned curryLevelIndex : indices(curryLevels)) {
    auto *curryLevel = curryLevels[curryLevelIndex];
    unsigned parameterIndexOffset =
        curryLevelParameterIndexOffsets[curryLevelIndex];
    for (unsigned paramIndex : range(curryLevel->getNumParams())) {
      if (parameterIndices->contains(parameterIndexOffset + paramIndex))
        continue;

      auto param = curryLevel->getParams()[paramIndex];
      if (param.isInOut()) {
        auto resultType = param.getPlainType();
        if (resultType->isVoid())
          continue;
        auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
        if (!resultTan)
          continue;
        auto resultTanType = resultTan->getType();
        resultTanTypes.push_back(resultTanType);
      }
    }
  }

  // Error if no semantic result has a tangent space.
  if (resultTanTypes.empty() && !hasInoutResult) {
    return llvm::make_error<DerivativeFunctionTypeError>(
        this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
        std::make_pair(originalResultType, /*index*/ 0));
        std::make_pair(originalResults.front().type, /*index*/ 0));
  }
  auto resultTanType = resultTan->getType();

  // Compute the result linear map function type.
  FunctionType *linearMapType;
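One way to read the accumulation logic above (a sketch under assumptions, not an authoritative description): non-inout formal results contribute their tangent types directly, while the curry-level walk exists to find `inout` parameters outside the wrt set so that their tangent types can be appended as additional result tangents. A hypothetical example of such a signature, with the `@differentiable` attribute omitted on purpose:

```swift
// Hypothetical example, @differentiable attribute omitted on purpose.
// If only `x` were a wrt parameter, `out` would be a non-wrt inout
// parameter; as I read the hunk above, its tangent type (Float) is
// appended to `resultTanTypes` next to the formal result's tangent.
func writeThrough(_ x: Float, _ out: inout Float) -> Float {
  out = x * 3
  return x + 1
}
```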
@@ -6472,11 +6527,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
    // - Original: `(T0, inout T1, ...) -> Void`
    // - Differential: `(T0.Tan, ...) -> T1.Tan`
    //
    // Case 3: original function has a wrt `inout` parameter.
    // - Original: `(T0, inout T1, ...) -> Void`
    // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
    // Case 3: original function has wrt `inout` parameters.
    // - Original: `(T0, inout T1, ...) -> R`
    // - Differential: `(T0.Tan, inout T1.Tan, ...) -> R.Tan`
    SmallVector<AnyFunctionType::Param, 4> differentialParams;
    bool hasInoutDiffParameter = false;
    for (auto i : range(diffParams.size())) {
      auto diffParam = diffParams[i];
      auto paramType = diffParam.getPlainType();
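A concrete instantiation of the updated differential Case 3 above, assuming `T0 = T1 = R = Float` (so each type's `TangentVector` is itself); the typealiases are purely illustrative:

```swift
// Original:     (T0, inout T1) -> R      with T0 = T1 = R = Float
// Differential: (T0.Tan, inout T1.Tan) -> R.Tan
typealias OriginalFn     = (Float, inout Float) -> Float
typealias DifferentialFn = (Float, inout Float) -> Float
```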
@@ -6491,11 +6545,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
      }
      differentialParams.push_back(AnyFunctionType::Param(
          paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
      if (diffParam.isInOut())
        hasInoutDiffParameter = true;
    }
    auto differentialResult =
        hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType;
    Type differentialResult;
    if (resultTanTypes.empty()) {
      differentialResult = ctx.TheEmptyTupleType;
    } else if (resultTanTypes.size() == 1) {
      differentialResult = resultTanTypes.front();
    } else {
      SmallVector<TupleTypeElt, 2> differentialResults;
      for (auto i : range(resultTanTypes.size())) {
        auto resultTanType = resultTanTypes[i];
        differentialResults.push_back(
            TupleTypeElt(resultTanType, Identifier()));
      }
      differentialResult = TupleType::get(differentialResults, ctx);
    }

    // FIXME: Verify ExtInfo state is correct, not working by accident.
    FunctionType::ExtInfo info;
    linearMapType =
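The new `differentialResult` logic uses the empty tuple when no result tangents were collected, the single tangent type when there is one, and a tuple when there are several. A hedged sketch of the multi-result case, assuming tuple elements are extracted as separate semantic results (per the tuple-result commit above) and a toolchain that accepts such declarations:

```swift
import _Differentiation

// Hypothetical example: a tuple result whose elements each have
// TangentVector == Self. Whether this exact declaration is accepted
// depends on toolchain support for multiple semantic results.
@differentiable(reverse)
func fanOut(_ x: Float) -> (Float, Float) {
  return (x * 2, x * 3)
}

// Expected differential shape under this change:
//   (Float) -> (Float, Float)
```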
@@ -6513,11 +6578,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
    // - Original: `(T0, inout T1, ...) -> Void`
    // - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
    //
    // Case 3: original function has a wrt `inout` parameter.
    // - Original: `(T0, inout T1, ...) -> Void`
    // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
    // Case 3: original function has wrt `inout` parameters.
    // - Original: `(T0, inout T1, ...) -> R`
    // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
    SmallVector<TupleTypeElt, 4> pullbackResults;
    bool hasInoutDiffParameter = false;
    SmallVector<AnyFunctionType::Param, 2> inoutParams;
    for (auto i : range(diffParams.size())) {
      auto diffParam = diffParams[i];
      auto paramType = diffParam.getPlainType();
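Likewise, a concrete instantiation of the updated pullback Case 3 above, assuming `T0 = Float`, `T1 = Double`, `R = Float` (each with `TangentVector` equal to itself); illustrative typealiases only:

```swift
// Original: (T0, inout T1) -> R      with T0 = Float, T1 = Double, R = Float
// Pullback: (R.Tan, inout T1.Tan) -> T0.Tan
typealias OriginalFn = (Float, inout Double) -> Float
typealias PullbackFn = (Float, inout Double) -> Float
```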
@@ -6531,7 +6596,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
            std::make_pair(paramType, i));
      }
      if (diffParam.isInOut()) {
        hasInoutDiffParameter = true;
        if (paramType->isVoid())
          continue;
        inoutParams.push_back(diffParam);
        continue;
      }
      pullbackResults.emplace_back(paramTan->getType());
@@ -6544,12 +6611,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
    } else {
      pullbackResult = TupleType::get(pullbackResults, ctx);
    }
    auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter);
    auto pullbackParam =
        AnyFunctionType::Param(resultTanType, Identifier(), flags);
    // First accumulate non-inout results as pullback parameters.
    SmallVector<FunctionType::Param, 2> pullbackParams;
    for (auto i : range(resultTanTypes.size())) {
      auto resultTanType = resultTanTypes[i];
      auto flags = ParameterTypeFlags().withInOut(false);
      pullbackParams.push_back(AnyFunctionType::Param(
          resultTanType, Identifier(), flags));
    }
    // Then append inout parameters.
    for (auto i : range(inoutParams.size())) {
      auto inoutParam = inoutParams[i];
      auto inoutParamType = inoutParam.getPlainType();
      auto inoutParamTan =
          inoutParamType->getAutoDiffTangentSpace(lookupConformance);
      auto flags = ParameterTypeFlags().withInOut(true);
      pullbackParams.push_back(AnyFunctionType::Param(
          inoutParamTan->getType(), Identifier(), flags));
    }
    // FIXME: Verify ExtInfo state is correct, not working by accident.
    FunctionType::ExtInfo info;
    linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info);
    linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
    break;
  }
  }
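Putting the pullback construction together: non-inout result tangents become the leading pullback parameters, and wrt `inout` parameter tangents follow as `inout` parameters. A hedged sketch with hypothetical names, assuming a toolchain that accepts multiple semantic results:

```swift
import _Differentiation

// Hypothetical example: two wrt non-inout parameters, one wrt inout
// parameter, and a formal result.
@differentiable(reverse)
func blend(_ a: Float, _ b: Float, _ state: inout Float) -> Float {
  state += a * b
  return a + b
}

// Expected pullback shape as I read the code above:
//   (Float, inout Float) -> (Float, Float)
// where the leading parameter is the result tangent, the inout parameter is
// `state`'s tangent, and the returned tuple holds the tangents of `a` and `b`.
```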
Review comment: Could you please clarify what this comment means exactly?
Reply: I'm specifically thinking of this test case: https://github.com/apple/swift/blob/72866b6dd9e6024ab97469cafb56333800bbeb67/test/AutoDiff/Sema/derivative_attr_type_checking.swift#L867, involving an `inout Void`. `getAutoDiffTangentSpace()` returns an empty tuple for Void (https://github.com/apple/swift/blob/main/lib/AST/Type.cpp#L5359) rather than None, so we can't rely on the `if (!resultTan)` check to filter them out in that one case. I didn't want to alter the behavior of `getAutoDiffTangentSpace()` for this one edge case, so I opted for detecting Voids as a special case.
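For reference, a hedged sketch of the edge case described above; this is a guess at the shape of the linked test, not a copy of it:

```swift
// Hypothetical shape only; not copied from derivative_attr_type_checking.swift.
// `Void`'s tangent space is the empty tuple, not "no tangent space", so a
// missing-tangent check alone would not filter out this inout parameter,
// which is why the new code checks isVoid() explicitly.
func takesInoutVoid(_ x: Float, _ v: inout Void) -> Float {
  return x * 2
}
```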