diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 7e4f79047a63b..54f3effa03ad1 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -5334,6 +5334,12 @@ ERROR(differentiable_function_type_invalid_result,none, "%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing " "function type is '@differentiable%select{|(_linear)}1'", (StringRef, bool)) +ERROR(differentiable_function_type_void_result, + none, + "'@differentiable' function returning Void must have at least one " + "differentiable inout parameter, i.e. a non-'@noDerivative' parameter " + "whose type conforms to 'Differentiable'", + ()) ERROR(differentiable_function_type_no_differentiability_parameters, none, "'@differentiable' function type requires at least one differentiability " diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp index 8c74e0142838a..aa92ed614f95e 100644 --- a/lib/Sema/TypeChecker.cpp +++ b/lib/Sema/TypeChecker.cpp @@ -720,6 +720,28 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc, diag.highlight((*repr)->getResultTypeRepr()->getSourceRange()); } } + + // If the result type is void, we need to have at least one differentiable + // inout argument + if (result->isVoid() && + llvm::find_if(params, + [&](AnyFunctionType::Param param) { + if (param.isNoDerivative()) + return false; + return param.isInOut() && + TypeChecker::isDifferentiable(param.getPlainType(), + /*tangentVectorEqualsSelf*/ isLinear, + dc, stage); + }) == params.end()) { + auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc; + auto resultStr = fnTy->getResult()->getString(); + auto diag = ctx.Diags.diagnose( + diagLoc, diag::differentiable_function_type_void_result); + hadAnyError = true; + + if (repr) + diag.highlight((*repr)->getResultTypeRepr()->getSourceRange()); + } } return hadAnyError; diff --git a/test/AutoDiff/Sema/differentiable_func_type.swift b/test/AutoDiff/Sema/differentiable_func_type.swift index b904e4a9bf384..2bc467bac7dfd 100644 --- a/test/AutoDiff/Sema/differentiable_func_type.swift +++ b/test/AutoDiff/Sema/differentiable_func_type.swift @@ -41,6 +41,10 @@ let _: @differentiable(_linear) (Float) -> NonDiffType let _: @differentiable(_linear) (Float) -> Float +// expected-error @+1 {{'@differentiable' function returning Void must have at least one differentiable inout parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}} +let _: @differentiable(reverse) (Float) -> Void +let _: @differentiable(reverse) (inout Float) -> Void // okay + // expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}} func test1(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Float) {} // expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}