Skip to content

[AutoDiff] Support differentiation of functions with multiple results in SIL. #32629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/swift/SILOptimizer/Differentiation/PullbackEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
/// A set used to remember local allocations that were destroyed.
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;

/// The seed argument in the pullback function.
SILArgument *seed = nullptr;
/// The seed arguments of the pullback function.
SmallVector<SILArgument *, 4> seeds;

llvm::BumpPtrAllocator allocator;

Expand Down
88 changes: 51 additions & 37 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,12 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
/// Collects the semantic results of the given function type in
/// `originalResults`. The semantic results are formal results followed by
/// `inout` parameters, in type order.
// TODO(TF-983): Generalize to support multiple `inout` parameters. The current
// singular `inoutParam` and `isWrtInoutParameter` are hacky.
static void
getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
Optional<SILParameterInfo> &inoutParam,
bool &isWrtInoutParameter,
IndexSubset *&inoutParameterIndices,
SmallVectorImpl<SILResultInfo> &originalResults) {
inoutParam = None;
isWrtInoutParameter = false;
auto &C = functionType->getASTContext();
SmallVector<unsigned, 4> inoutParamIndices;
// Collect original formal results.
originalResults.append(functionType->getResults().begin(),
functionType->getResults().end());
Expand All @@ -332,11 +329,12 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
auto param = functionType->getParameters()[i];
if (!param.isIndirectInOut())
continue;
inoutParam = param;
isWrtInoutParameter = parameterIndices->contains(i);
inoutParamIndices.push_back(i);
originalResults.push_back(
SILResultInfo(param.getInterfaceType(), ResultConvention::Indirect));
}
inoutParameterIndices =
IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices);
}

/// Returns the differential type for the given original function type,
Expand Down Expand Up @@ -402,11 +400,10 @@ static CanSILFunctionType getAutoDiffDifferentialType(
SmallVector<Type, 4> substReplacements;
SmallVector<ProtocolConformanceRef, 4> substConformances;

Optional<SILParameterInfo> inoutParam = None;
bool isWrtInoutParameter = false;
IndexSubset *inoutParamIndices;
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
isWrtInoutParameter, originalResults);
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);

SmallVector<SILParameterInfo, 4> diffParams;
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
Expand All @@ -430,7 +427,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
}
}
SmallVector<SILResultInfo, 1> differentialResults;
if (!inoutParam || !isWrtInoutParameter) {
if (inoutParamIndices->isEmpty()) {
for (auto resultIndex : resultIndices->getIndices()) {
auto &result = originalResults[resultIndex];
auto resultTan =
Expand Down Expand Up @@ -480,11 +477,10 @@ static CanSILFunctionType getAutoDiffPullbackType(
SmallVector<Type, 4> substReplacements;
SmallVector<ProtocolConformanceRef, 4> substConformances;

Optional<SILParameterInfo> inoutParam = None;
bool isWrtInoutParameter = false;
IndexSubset *inoutParamIndices;
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, parameterIndices, inoutParam,
isWrtInoutParameter, originalResults);
getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices,
originalResults);

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterConventionForOriginalResult =
Expand Down Expand Up @@ -551,27 +547,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
return conv;
};

// Collect pullback parameters.
SmallVector<SILParameterInfo, 1> pullbackParams;
if (inoutParam) {
auto paramTan = inoutParam->getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
auto paramTanConvention = isWrtInoutParameter
? inoutParam->getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
auto paramTanType = paramTan->getCanonicalType();
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackParams.push_back(
SILParameterInfo(paramTanType, paramTanConvention));
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
} else {
for (auto resultIndex : resultIndices->getIndices()) {
for (auto resultIndex : resultIndices->getIndices()) {
// Handle formal original result.
if (resultIndex < originalFnTy->getNumResults()) {
auto &origRes = originalResults[resultIndex];
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
Expand All @@ -590,12 +570,46 @@ static CanSILFunctionType getAutoDiffPullbackType(
substReplacements.push_back(resultTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
continue;
}
// Handle original `inout` parameter.
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
auto inoutParamIt = std::next(
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
auto inoutParam = originalFnTy->getParameters()[paramIndex];
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
// The pullback parameter convention depends on whether the original `inout`
// parameter is a differentiability parameter.
// - If yes, the pullback parameter convention is `@inout`.
// - If no, the pullback parameter convention is `@in_guaranteed`.
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
auto paramTanConvention = isWrtInoutParameter
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
auto paramTanType = paramTan->getCanonicalType();
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
pullbackParams.push_back(
SILParameterInfo(paramTanType, paramTanConvention));
} else {
auto gpIndex = substGenericParams.size();
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
substGenericParams.push_back(gpType);
substReplacements.push_back(paramTanType);
pullbackParams.push_back({gpType, paramTanConvention});
}
}

// Collect pullback results.
SmallVector<SILParameterInfo, 4> diffParams;
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
SmallVector<SILResultInfo, 8> pullbackResults;
for (auto &param : diffParams) {
// Skip `inout` parameters, which semantically behave as original results
// and always appear as pullback parameters.
if (param.isIndirectInOut())
continue;
auto paramTan =
Expand Down
72 changes: 37 additions & 35 deletions lib/SILOptimizer/Differentiation/PullbackEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,44 +1007,47 @@ bool PullbackEmitter::run() {
}

auto *pullbackEntry = pullback.getEntryBlock();
// The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
// The pullback function has type:
// `(seed0, seed1, ..., exit_pb_struct) -> (d_arg0, ..., d_argn)`.
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
assert(pbParamArgs.size() == 2);
seed = pbParamArgs[0];
// TODO(TF-983): Handle multiple original results.
assert(getIndices().results->getNumIndices() == 1);
auto origResult = origFormalResults[*getIndices().results->begin()];

// Assign adjoint for original result.
assert(getIndices().results->getNumIndices() == pbParamArgs.size() - 1 &&
pbParamArgs.size() >= 2);
// Assign adjoints for original results.
builder.setInsertionPoint(pullbackEntry,
getNextFunctionLocalAllocationInsertionPoint());
if (seed->getType().isAddress()) {
// If the pullback `seed` is an `inout` parameter, assign it directly as the
// adjoint buffer of the original result.
if (pullback.getLoweredFunctionType()
->getParameters()
.front()
.isIndirectInOut()) {
setAdjointBuffer(origExit, origResult, seed);
}
// Otherwise, assign a copy of `seed` as the adjoint buffer of the original
// result.
else {
auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType());
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
IsInitialization);
functionLocalAllocations.push_back(seedBufCopy);
setAdjointBuffer(origExit, origResult, seedBufCopy);
unsigned seedIndex = 0;
for (auto resultIndex : getIndices().results->getIndices()) {
auto origResult = origFormalResults[resultIndex];
auto *seed = pbParamArgs[seedIndex];
if (seed->getType().isAddress()) {
// If the seed argument is an `inout` parameter, assign it directly as
// the adjoint buffer of the original result.
auto seedParamInfo =
pullback.getLoweredFunctionType()->getParameters()[seedIndex];
if (seedParamInfo.isIndirectInOut()) {
setAdjointBuffer(origExit, origResult, seed);
}
// Otherwise, assign a copy of the seed argument as the adjoint buffer of
// the original result.
else {
auto *seedBufCopy =
createFunctionLocalAllocation(seed->getType(), pbLoc);
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
IsInitialization);
setAdjointBuffer(origExit, origResult, seedBufCopy);
LLVM_DEBUG(getADDebugStream()
<< "Assigned seed buffer " << *seedBufCopy
<< " as the adjoint of original indirect result "
<< origResult);
}
} else {
addAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed),
pbLoc);
LLVM_DEBUG(getADDebugStream()
<< "Assigned seed buffer " << seedBufCopy
<< " as the adjoint of original indirect result "
<< origResult);
<< "Assigned seed " << *seed
<< " as the adjoint of original result " << origResult);
}
} else {
setAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed));
LLVM_DEBUG(getADDebugStream()
<< "Assigned seed " << *seed
<< " as the adjoint of original result " << origResult);
++seedIndex;
}

// If the original function is an accessor with special-case pullback
Expand Down Expand Up @@ -1573,8 +1576,7 @@ void PullbackEmitter::visitApplyInst(ApplyInst *ai) {
args.push_back(alloc);
}

// Get formal callee pullback arguments.
assert(applyInfo.indices.results->getNumIndices() == 1);
// Collect callee pullback formal arguments.
for (auto resultIndex : applyInfo.indices.results->getIndices()) {
assert(resultIndex < origAllResults.size());
auto origResult = origAllResults[resultIndex];
Expand Down
66 changes: 36 additions & 30 deletions lib/SILOptimizer/Differentiation/VJPEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,32 +155,18 @@ SILFunction *VJPEmitter::createEmptyPullback() {
auto origParams = origTy->getParameters();
auto indices = witness->getSILAutoDiffIndices();

// Add pullback parameter for the seed.
Optional<SILParameterInfo> inoutParam;
bool isWrtInoutParam = false;
// Add pullback parameters based on original result indices.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the pullback type calculation logic here in VJPEmitter::createEmptyPullback() duplicates SILFunctionType::getAutoDiffDerivativeFunctionType. It should be deduplicated; I'll look into it soon.

SmallVector<unsigned, 4> inoutParamIndices;
for (auto i : range(origTy->getNumParameters())) {
auto origParam = origParams[i];
if (!origParam.isIndirectInOut())
continue;
isWrtInoutParam = indices.parameters->contains(i);
inoutParam = origParam;
inoutParamIndices.push_back(i);
}
if (inoutParam) {
auto origResult = inoutParam->getWithInterfaceType(
inoutParam->getInterfaceType()->getCanonicalType(witnessCanGenSig));
auto inoutParamTanConvention =
isWrtInoutParam ? inoutParam->getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
SILParameterInfo inoutParamTanParam(
origResult.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getType()
->getCanonicalType(witnessCanGenSig),
inoutParamTanConvention);
pbParams.push_back(inoutParamTanParam);
} else {
for (auto i : indices.results->getIndices()) {
auto origResult = origTy->getResults()[i];
for (auto resultIndex : indices.results->getIndices()) {
// Handle formal result.
if (resultIndex < origTy->getNumResults()) {
auto origResult = origTy->getResults()[resultIndex];
origResult = origResult.getWithInterfaceType(
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
pbParams.push_back(getTangentParameterInfoForOriginalResult(
Expand All @@ -189,7 +175,36 @@ SILFunction *VJPEmitter::createEmptyPullback() {
->getType()
->getCanonicalType(witnessCanGenSig),
origResult.getConvention()));
continue;
}
// Handle `inout` parameter.
unsigned paramIndex = 0;
unsigned inoutParamIndex = 0;
for (auto i : range(origTy->getNumParameters())) {
auto origParam = origTy->getParameters()[i];
if (!origParam.isIndirectMutating()) {
++paramIndex;
continue;
}
if (inoutParamIndex == resultIndex - origTy->getNumResults())
break;
++paramIndex;
++inoutParamIndex;
}
auto inoutParam = origParams[paramIndex];
auto origResult = inoutParam.getWithInterfaceType(
inoutParam.getInterfaceType()->getCanonicalType(witnessCanGenSig));
auto inoutParamTanConvention =
indices.isWrtParameter(paramIndex)
? inoutParam.getConvention()
: ParameterConvention::Indirect_In_Guaranteed;
SILParameterInfo inoutParamTanParam(
origResult.getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getType()
->getCanonicalType(witnessCanGenSig),
inoutParamTanConvention);
pbParams.push_back(inoutParamTanParam);
}

// Accept a pullback struct in the pullback parameter list. This is the
Expand Down Expand Up @@ -587,15 +602,6 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) {
activeResultIndices.begin(), activeResultIndices.end(),
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
s << "}\n";);
// Diagnose multiple active results.
// TODO(TF-983): Support multiple active results.
if (activeResultIndices.size() > 1) {
context.emitNondifferentiabilityError(
ai, invoker,
diag::autodiff_cannot_differentiate_through_multiple_results);
errorOccurred = true;
return;
}

// Form expected indices.
auto numSemanticResults =
Expand Down
12 changes: 0 additions & 12 deletions test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,8 @@ func multipleResults(_ x: Float) -> (Float, Float) {
return (x, x)
}

// TODO(TF-983): Support differentiation of multiple results.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+2 {{when differentiating this function definition}}
@differentiable
func usesMultipleResults(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through multiple results}}
let tuple = multipleResults(x)
return tuple.0 + tuple.1
}
Expand Down Expand Up @@ -440,27 +436,19 @@ func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) {
nonactive = result.0
}

// TODO(TF-983): Support differentiation of multiple results.
func twoInoutParameters(_ x: inout Float, _ y: inout Float) {}
// expected-error @+2 {{function is not differentiable}}
// expected-note @+2 {{when differentiating this function definition}}
@differentiable
func testTwoInoutParameters(_ x: Float, _ y: Float) -> Float {
var x = x
var y = y
// expected-note @+1 {{cannot differentiate through multiple results}}
twoInoutParameters(&x, &y)
return x
}

// TODO(TF-983): Support differentiation of multiple results.
func inoutParameterAndFormalResult(_ x: inout Float) -> Float { x }
// expected-error @+2 {{function is not differentiable}}
// expected-note @+2 {{when differentiating this function definition}}
@differentiable
func testInoutParameterAndFormalResult(_ x: Float) -> Float {
var x = x
// expected-note @+1 {{cannot differentiate through multiple results}}
return inoutParameterAndFormalResult(&x)
}

Expand Down
Loading