Skip to content

Commit 0763e4b

Browse files
authored
Diagnose differentiable functions returning Void w/o inout arguments. (#63080)
Such functions are not differentiable and therefore should be rejected. Fixes #62923, fixes #58095
1 parent e916781 commit 0763e4b

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

include/swift/AST/DiagnosticsSema.def

+6
Original file line numberDiff line numberDiff line change
@@ -5337,6 +5337,12 @@ ERROR(differentiable_function_type_invalid_result,none,
53375337
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
53385338
"function type is '@differentiable%select{|(_linear)}1'",
53395339
(StringRef, bool))
5340+
ERROR(differentiable_function_type_void_result,
5341+
none,
5342+
"'@differentiable' function returning Void must have at least one "
5343+
"differentiable inout parameter, i.e. a non-'@noDerivative' parameter "
5344+
"whose type conforms to 'Differentiable'",
5345+
())
53405346
ERROR(differentiable_function_type_no_differentiability_parameters,
53415347
none,
53425348
"'@differentiable' function type requires at least one differentiability "

lib/Sema/TypeChecker.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,28 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
720720
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
721721
}
722722
}
723+
724+
// If the result type is void, we need to have at least one differentiable
725+
// inout argument
726+
if (result->isVoid() &&
727+
llvm::find_if(params,
728+
[&](AnyFunctionType::Param param) {
729+
if (param.isNoDerivative())
730+
return false;
731+
return param.isInOut() &&
732+
TypeChecker::isDifferentiable(param.getPlainType(),
733+
/*tangentVectorEqualsSelf*/ isLinear,
734+
dc, stage);
735+
}) == params.end()) {
736+
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
737+
auto resultStr = fnTy->getResult()->getString();
738+
auto diag = ctx.Diags.diagnose(
739+
diagLoc, diag::differentiable_function_type_void_result);
740+
hadAnyError = true;
741+
742+
if (repr)
743+
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
744+
}
723745
}
724746

725747
return hadAnyError;

test/AutoDiff/Sema/differentiable_func_type.swift

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ let _: @differentiable(_linear) (Float) -> NonDiffType
4141

4242
let _: @differentiable(_linear) (Float) -> Float
4343

44+
// expected-error @+1 {{'@differentiable' function returning Void must have at least one differentiable inout parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
45+
let _: @differentiable(reverse) (Float) -> Void
46+
let _: @differentiable(reverse) (inout Float) -> Void // okay
47+
4448
// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
4549
func test1<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Float) {}
4650
// expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}

0 commit comments

Comments
 (0)