diff --git a/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h b/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h index fb66d4b79b885..ccdfb99768de1 100644 --- a/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h @@ -105,8 +105,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> { /// A set used to remember local allocations that were destroyed. llvm::SmallDenseSet<SILValue> destroyedLocalAllocations; - /// The seed argument in the pullback function. - SILArgument *seed = nullptr; + /// The seed arguments of the pullback function. + SmallVector<SILArgument *, 4> seeds; llvm::BumpPtrAllocator allocator; diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 79659d8dcdb1a..0091d35779c08 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -315,15 +315,12 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy, /// Collects the semantic results of the given function type in /// `originalResults`. The semantic results are formal results followed by /// `inout` parameters, in type order. -// TODO(TF-983): Generalize to support multiple `inout` parameters. The current -// singular `inoutParam` and `isWrtInoutParameter` are hacky. static void getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, - Optional<SILParameterInfo> &inoutParam, - bool &isWrtInoutParameter, + IndexSubset *&inoutParameterIndices, SmallVectorImpl<SILResultInfo> &originalResults) { - inoutParam = None; - isWrtInoutParameter = false; + auto &C = functionType->getASTContext(); + SmallVector<unsigned, 4> inoutParamIndices; // Collect original formal results. originalResults.append(functionType->getResults().begin(), functionType->getResults().end()); @@ -332,11 +329,12 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, auto param = functionType->getParameters()[i]; if (!param.isIndirectInOut()) continue; - inoutParam = param; - isWrtInoutParameter = parameterIndices->contains(i); + inoutParamIndices.push_back(i); originalResults.push_back( SILResultInfo(param.getInterfaceType(), ResultConvention::Indirect)); } + inoutParameterIndices = + IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices); } /// Returns the differential type for the given original function type, @@ -402,11 +400,10 @@ static CanSILFunctionType getAutoDiffDifferentialType( SmallVector<Type, 4> substReplacements; SmallVector<ProtocolConformanceRef, 4> substConformances; - Optional<SILParameterInfo> inoutParam = None; - bool isWrtInoutParameter = false; + IndexSubset *inoutParamIndices; SmallVector<SILResultInfo, 2> originalResults; - getSemanticResults(originalFnTy, parameterIndices, inoutParam, - isWrtInoutParameter, originalResults); + getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices, + originalResults); SmallVector<SILParameterInfo, 4> diffParams; getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); @@ -430,7 +427,7 @@ static CanSILFunctionType getAutoDiffDifferentialType( } } SmallVector<SILResultInfo, 1> differentialResults; - if (!inoutParam || !isWrtInoutParameter) { + if (inoutParamIndices->isEmpty()) { for (auto resultIndex : resultIndices->getIndices()) { auto &result = originalResults[resultIndex]; auto resultTan = @@ -480,11 +477,10 @@ static CanSILFunctionType getAutoDiffPullbackType( SmallVector<Type, 4> substReplacements; SmallVector<ProtocolConformanceRef, 4> substConformances; - Optional<SILParameterInfo> inoutParam = None; - bool isWrtInoutParameter = false; + IndexSubset *inoutParamIndices; SmallVector<SILResultInfo, 2> originalResults; - getSemanticResults(originalFnTy, parameterIndices, inoutParam, - isWrtInoutParameter, originalResults); + getSemanticResults(originalFnTy, parameterIndices, inoutParamIndices, + originalResults); // Given a type, returns its formal SIL parameter info. auto getTangentParameterConventionForOriginalResult = @@ -551,27 +547,11 @@ static CanSILFunctionType getAutoDiffPullbackType( return conv; }; + // Collect pullback parameters. SmallVector<SILParameterInfo, 1> pullbackParams; - if (inoutParam) { - auto paramTan = inoutParam->getInterfaceType()->getAutoDiffTangentSpace( - lookupConformance); - assert(paramTan && "Parameter type does not have a tangent space?"); - auto paramTanConvention = isWrtInoutParameter - ? inoutParam->getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - auto paramTanType = paramTan->getCanonicalType(); - if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) { - pullbackParams.push_back( - SILParameterInfo(paramTanType, paramTanConvention)); - } else { - auto gpIndex = substGenericParams.size(); - auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx); - substGenericParams.push_back(gpType); - substReplacements.push_back(paramTanType); - pullbackParams.push_back({gpType, paramTanConvention}); - } - } else { - for (auto resultIndex : resultIndices->getIndices()) { + for (auto resultIndex : resultIndices->getIndices()) { + // Handle formal original result. + if (resultIndex < originalFnTy->getNumResults()) { auto &origRes = originalResults[resultIndex]; auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace( lookupConformance); @@ -590,12 +570,46 @@ static CanSILFunctionType getAutoDiffPullbackType( substReplacements.push_back(resultTanType); pullbackParams.push_back({gpType, paramTanConvention}); } + continue; + } + // Handle original `inout` parameter. + auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); + auto inoutParamIt = std::next( + originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); + auto paramIndex = + std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); + auto inoutParam = originalFnTy->getParameters()[paramIndex]; + auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace( + lookupConformance); + assert(paramTan && "Parameter type does not have a tangent space?"); + // The pullback parameter convention depends on whether the original `inout` + // paramater is a differentiability parameter. + // - If yes, the pullback parameter convention is `@inout`. + // - If no, the pullback parameter convention is `@in_guaranteed`. + bool isWrtInoutParameter = parameterIndices->contains(paramIndex); + auto paramTanConvention = isWrtInoutParameter + ? inoutParam.getConvention() + : ParameterConvention::Indirect_In_Guaranteed; + auto paramTanType = paramTan->getCanonicalType(); + if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) { + pullbackParams.push_back( + SILParameterInfo(paramTanType, paramTanConvention)); + } else { + auto gpIndex = substGenericParams.size(); + auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx); + substGenericParams.push_back(gpType); + substReplacements.push_back(paramTanType); + pullbackParams.push_back({gpType, paramTanConvention}); } } + + // Collect pullback results. SmallVector<SILParameterInfo, 4> diffParams; getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); SmallVector<SILResultInfo, 8> pullbackResults; for (auto ¶m : diffParams) { + // Skip `inout` parameters, which semantically behave as original results + // and always appear as pullback parameters. if (param.isIndirectInOut()) continue; auto paramTan = diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp index 48e937ec63176..a5fae065e4e67 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp @@ -1007,44 +1007,47 @@ bool PullbackEmitter::run() { } auto *pullbackEntry = pullback.getEntryBlock(); - // The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]). + // The pullback function has type: + // `(seed0, seed1, ..., exit_pb_struct) -> (d_arg0, ..., d_argn)`. auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); - assert(pbParamArgs.size() == 2); - seed = pbParamArgs[0]; - // TODO(TF-983): Handle multiple original results. - assert(getIndices().results->getNumIndices() == 1); - auto origResult = origFormalResults[*getIndices().results->begin()]; - - // Assign adjoint for original result. + assert(getIndices().results->getNumIndices() == pbParamArgs.size() - 1 && + pbParamArgs.size() >= 2); + // Assign adjoints for original result. builder.setInsertionPoint(pullbackEntry, getNextFunctionLocalAllocationInsertionPoint()); - if (seed->getType().isAddress()) { - // If the pullback `seed` is an `inout` parameter, assign it directly as the - // adjoint buffer of the original result. - if (pullback.getLoweredFunctionType() - ->getParameters() - .front() - .isIndirectInOut()) { - setAdjointBuffer(origExit, origResult, seed); - } - // Otherwise, assign a copy of `seed` as the adjoint buffer of the original - // result. - else { - auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType()); - builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, - IsInitialization); - functionLocalAllocations.push_back(seedBufCopy); - setAdjointBuffer(origExit, origResult, seedBufCopy); + unsigned seedIndex = 0; + for (auto resultIndex : getIndices().results->getIndices()) { + auto origResult = origFormalResults[resultIndex]; + auto *seed = pbParamArgs[seedIndex]; + if (seed->getType().isAddress()) { + // If the seed argument is an `inout` parameter, assign it directly as + // the adjoint buffer of the original result. + auto seedParamInfo = + pullback.getLoweredFunctionType()->getParameters()[seedIndex]; + if (seedParamInfo.isIndirectInOut()) { + setAdjointBuffer(origExit, origResult, seed); + } + // Otherwise, assign a copy of the seed argument as the adjoint buffer of + // the original result. + else { + auto *seedBufCopy = + createFunctionLocalAllocation(seed->getType(), pbLoc); + builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, + IsInitialization); + setAdjointBuffer(origExit, origResult, seedBufCopy); + LLVM_DEBUG(getADDebugStream() + << "Assigned seed buffer " << *seedBufCopy + << " as the adjoint of original indirect result " + << origResult); + } + } else { + addAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed), + pbLoc); LLVM_DEBUG(getADDebugStream() - << "Assigned seed buffer " << seedBufCopy - << " as the adjoint of original indirect result " - << origResult); + << "Assigned seed " << *seed + << " as the adjoint of original result " << origResult); } - } else { - setAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed)); - LLVM_DEBUG(getADDebugStream() - << "Assigned seed " << *seed - << " as the adjoint of original result " << origResult); + ++seedIndex; } // If the original function is an accessor with special-case pullback @@ -1573,8 +1576,7 @@ void PullbackEmitter::visitApplyInst(ApplyInst *ai) { args.push_back(alloc); } - // Get formal callee pullback arguments. - assert(applyInfo.indices.results->getNumIndices() == 1); + // Collect callee pullback formal arguments. for (auto resultIndex : applyInfo.indices.results->getIndices()) { assert(resultIndex < origAllResults.size()); auto origResult = origAllResults[resultIndex]; diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index a5e2b80348dae..c9eab1f3b39e7 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -155,32 +155,18 @@ SILFunction *VJPEmitter::createEmptyPullback() { auto origParams = origTy->getParameters(); auto indices = witness->getSILAutoDiffIndices(); - // Add pullback parameter for the seed. - Optional<SILParameterInfo> inoutParam; - bool isWrtInoutParam = false; + // Add pullback parameters based on original result indices. + SmallVector<unsigned, 4> inoutParamIndices; for (auto i : range(origTy->getNumParameters())) { auto origParam = origParams[i]; if (!origParam.isIndirectInOut()) continue; - isWrtInoutParam = indices.parameters->contains(i); - inoutParam = origParam; + inoutParamIndices.push_back(i); } - if (inoutParam) { - auto origResult = inoutParam->getWithInterfaceType( - inoutParam->getInterfaceType()->getCanonicalType(witnessCanGenSig)); - auto inoutParamTanConvention = - isWrtInoutParam ? inoutParam->getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - SILParameterInfo inoutParamTanParam( - origResult.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getCanonicalType(witnessCanGenSig), - inoutParamTanConvention); - pbParams.push_back(inoutParamTanParam); - } else { - for (auto i : indices.results->getIndices()) { - auto origResult = origTy->getResults()[i]; + for (auto resultIndex : indices.results->getIndices()) { + // Handle formal result. + if (resultIndex < origTy->getNumResults()) { + auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig)); pbParams.push_back(getTangentParameterInfoForOriginalResult( @@ -189,7 +175,36 @@ SILFunction *VJPEmitter::createEmptyPullback() { ->getType() ->getCanonicalType(witnessCanGenSig), origResult.getConvention())); + continue; + } + // Handle `inout` parameter. + unsigned paramIndex = 0; + unsigned inoutParamIndex = 0; + for (auto i : range(origTy->getNumParameters())) { + auto origParam = origTy->getParameters()[i]; + if (!origParam.isIndirectMutating()) { + ++paramIndex; + continue; + } + if (inoutParamIndex == resultIndex - origTy->getNumResults()) + break; + ++paramIndex; + ++inoutParamIndex; } + auto inoutParam = origParams[paramIndex]; + auto origResult = inoutParam.getWithInterfaceType( + inoutParam.getInterfaceType()->getCanonicalType(witnessCanGenSig)); + auto inoutParamTanConvention = + indices.isWrtParameter(paramIndex) + ? inoutParam.getConvention() + : ParameterConvention::Indirect_In_Guaranteed; + SILParameterInfo inoutParamTanParam( + origResult.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getCanonicalType(witnessCanGenSig), + inoutParamTanConvention); + pbParams.push_back(inoutParamTanParam); } // Accept a pullback struct in the pullback parameter list. This is the @@ -587,15 +602,6 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) { activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << "}\n";); - // Diagnose multiple active results. - // TODO(TF-983): Support multiple active results. - if (activeResultIndices.size() > 1) { - context.emitNondifferentiabilityError( - ai, invoker, - diag::autodiff_cannot_differentiate_through_multiple_results); - errorOccurred = true; - return; - } // Form expected indices. auto numSemanticResults = diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index a038e01dafbf2..2398f39ad1d5f 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -346,12 +346,8 @@ func multipleResults(_ x: Float) -> (Float, Float) { return (x, x) } -// TODO(TF-983): Support differentiation of multiple results. -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable func usesMultipleResults(_ x: Float) -> Float { - // expected-note @+1 {{cannot differentiate through multiple results}} let tuple = multipleResults(x) return tuple.0 + tuple.1 } @@ -440,27 +436,19 @@ func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { nonactive = result.0 } -// TODO(TF-983): Support differentiation of multiple results. func twoInoutParameters(_ x: inout Float, _ y: inout Float) {} -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable func testTwoInoutParameters(_ x: Float, _ y: Float) -> Float { var x = x var y = y - // expected-note @+1 {{cannot differentiate through multiple results}} twoInoutParameters(&x, &y) return x } -// TODO(TF-983): Support differentiation of multiple results. func inoutParameterAndFormalResult(_ x: inout Float) -> Float { x } -// expected-error @+2 {{function is not differentiable}} -// expected-note @+2 {{when differentiating this function definition}} @differentiable func testInoutParameterAndFormalResult(_ x: Float) -> Float { var x = x - // expected-note @+1 {{cannot differentiate through multiple results}} return inoutParameterAndFormalResult(&x) } diff --git a/test/AutoDiff/validation-test/forward_mode.swift b/test/AutoDiff/validation-test/forward_mode.swift index 14bd1e38d00e7..22b049158c432 100644 --- a/test/AutoDiff/validation-test/forward_mode.swift +++ b/test/AutoDiff/validation-test/forward_mode.swift @@ -746,7 +746,7 @@ ForwardModeTests.test("SimpleWrtSelf") { // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial. var _nontrivial: [Float] = [] - // FIXME(SR-12175): Fix forward-mode differentiation crash. + // FIXME(SR-12175): Fix forward-mode differentiation tangent buffer crash. // @differentiable required init(base: Float) { self.base = base @@ -792,7 +792,7 @@ ForwardModeTests.test("SimpleWrtSelf") { } } - // FIXME(SR-12175): Fix forward-mode differentiation crash. + // FIXME(SR-12175): Fix forward-mode differentiation tangent buffer crash. // let v = Super.TangentVector(base: 100, _nontrivial: []) // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) @@ -1064,12 +1064,80 @@ ForwardModeTests.test("FunctionCall") { } ForwardModeTests.test("ResultSelection") { - func foo(_ x: Float, _ y: Float) -> (Float, Float) { + func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) } - expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).0 })) - expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).1 })) + expectEqual(1, derivative(at: 3, 3, in: { x, y in tuple(x, y).0 })) + expectEqual(1, derivative(at: 3, 3, in: { x, y in tuple(x, y).1 })) + + // FIXME(SR-12175): Fix forward-mode differentiation tangent buffer crash. + /* + func tupleGeneric<T>(_ x: T, _ y: T) -> (T, T) { + return (x, y) + } + func tupleGenericFirst<T>(_ x: T, _ y: T) -> T { tupleGeneric(x, y).0 } + func tupleGenericSecond<T>(_ x: T, _ y: T) -> T { tupleGeneric(x, y).1 } + expectEqual(1, derivative(at: 3, 3, in: tupleGenericFirst)) + expectEqual(1, derivative(at: 3, 3, in: tupleGenericSecond)) + */ +} + +// TODO(TF-983): Support forward-mode differentiation of multiple results. +/* +ForwardModeTests.test("MultipleResults") { + // Test function returning a tuple of active results. + func tuple(_ x: Float, _ y: Float) -> (Float, Float) { + return (x, y) + } + func multiply(_ x: Float, _ y: Float) -> Float { + let z = tuple(x, y) + // Note: both results (tuple elements) are active. + return z.0 * z.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply)) + + // Test function with multiple `inout` parameters. + func swap(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp + } + func multiply_swap(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swap(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swap)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply_swap)) + + // Test function with multiple `inout` parameters. + func swapGeneric<T>(_ x: inout T, _ y: inout T) { + let tmp = x; x = y; y = tmp + } + func multiply_swapGeneric(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapGeneric(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapGeneric)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply_swapGeneric)) + + // Test function with multiple `inout` parameters and a formal result. + func swapAndReturnProduct(_ x: inout Float, _ y: inout Float) -> Float { + let tmp = x + x = y + y = tmp + return x * y + } + func multiply_swapAndReturnProduct(_ x: Float, _ y: Float) -> Float { + var x2 = x + var y2 = y + let result = swapAndReturnProduct(&x2, &y2) + return result + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapAndReturnProduct)) + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapAndReturnProduct)) } +*/ ForwardModeTests.test("CaptureLocal") { let z: Float = 10 diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index 68c4ce442d550..d96332c5f4941 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -1,7 +1,11 @@ // RUN: %target-run-simple-swift + // NOTE(TF-813): verify that enabling forward-mode does not affect reverse-mode. -// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation) +// Temporarily disabled because forward-mode is not at feature parity with reverse-mode. +// UN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation) + // RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null -module-name null 2>&1 | %FileCheck %s + // REQUIRES: executable_test import StdlibUnittest @@ -48,11 +52,73 @@ SimpleMathTests.test("FunctionCall") { } SimpleMathTests.test("ResultSelection") { - func foo(_ x: Float, _ y: Float) -> (Float, Float) { + func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) } - expectEqual((1, 0), gradient(at: 3, 3, in: { x, y in foo(x, y).0 })) - expectEqual((0, 1), gradient(at: 3, 3, in: { x, y in foo(x, y).1 })) + expectEqual((1, 0), gradient(at: 3, 3, in: { x, y in tuple(x, y).0 })) + expectEqual((0, 1), gradient(at: 3, 3, in: { x, y in tuple(x, y).1 })) + + func tupleGeneric<T>(_ x: T, _ y: T) -> (T, T) { + return (x, y) + } + func tupleGenericFirst<T>(_ x: T, _ y: T) -> T { tupleGeneric(x, y).0 } + func tupleGenericSecond<T>(_ x: T, _ y: T) -> T { tupleGeneric(x, y).1 } + expectEqual((1, 0), gradient(at: 3, 3, in: tupleGenericFirst)) + expectEqual((0, 1), gradient(at: 3, 3, in: tupleGenericSecond)) +} + +SimpleMathTests.test("MultipleResults") { + // Test function returning a tuple of active results. + func tuple(_ x: Float, _ y: Float) -> (Float, Float) { + return (x, y) + } + func multiply(_ x: Float, _ y: Float) -> Float { + let z = tuple(x, y) + // Note: both results (tuple elements) are active. + return z.0 * z.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply)) + + // Test function with multiple `inout` parameters. + func swap(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp + } + func multiply_swap(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swap(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swap)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply_swap)) + + // Test function with multiple `inout` parameters. + func swapGeneric<T>(_ x: inout T, _ y: inout T) { + let tmp = x; x = y; y = tmp + } + func multiply_swapGeneric(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapGeneric(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapGeneric)) + expectEqual((10, 5), gradient(at: 5, 10, in: multiply_swapGeneric)) + + // Test function with multiple `inout` parameters and a formal result. + func swapAndReturnProduct(_ x: inout Float, _ y: inout Float) -> Float { + let tmp = x + x = y + y = tmp + return x * y + } + func multiply_swapAndReturnProduct(_ x: Float, _ y: Float) -> Float { + var x2 = x + var y2 = y + let result = swapAndReturnProduct(&x2, &y2) + return result + } + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapAndReturnProduct)) + expectEqual((4, 3), gradient(at: 3, 4, in: multiply_swapAndReturnProduct)) } SimpleMathTests.test("CaptureLocal") {