Skip to content

[AutoDiff][Sema] Properly type-check differentiability of parameters and results in differentiable function type declarations #41174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
15 changes: 11 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3462,8 +3462,8 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
ERROR(autodiff_attr_original_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot differentiate functions with both an 'inout' parameter and a "
"result", ())
"cannot differentiate functions with both a differentiable 'inout' "
"parameter and a differentiable result", ())
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
Expand Down Expand Up @@ -5040,12 +5040,19 @@ ERROR(differentiable_function_type_invalid_result,none,
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(_linear)}1'",
(StringRef, bool))
ERROR(differentiable_function_type_no_differentiability_parameters,
none,
ERROR(differentiable_function_type_multiple_semantic_results,none,
"'@differentiable' function type cannot have both a differentiable "
"'inout' parameter and a differentiable result", ())
ERROR(differentiable_function_type_no_differentiability_parameters,none,
"'@differentiable' function type requires at least one differentiability "
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
(/*isLinear*/ bool))
ERROR(differentiable_function_type_no_differentiable_result,none,
"'@differentiable' function type requires a differentiable result, i.e. "
"a non-'Void' type that conforms to 'Differentiable'%select{| with its "
"'TangentVector' equal to itself}0",
(/*isLinear*/ bool))

// SIL
ERROR(opened_non_protocol,none,
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6178,7 +6178,7 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
return tangentSpace;
};

// For tuple types: the tangent space is a tuple of the elements' tangent
// For tuple types: the tangent space is a tuple of the elements' tangent
// space types, for the elements that have a tangent space.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
Expand Down
46 changes: 40 additions & 6 deletions lib/Sema/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,13 +619,17 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
dc, stage);
}) != params.end();
bool alreadyDiagnosedOneParam = false;
bool hasDifferentiableInoutParameter = false;
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
auto param = params[i];
if (param.isNoDerivative())
continue;
auto paramType = param.getPlainType();
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage))
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage)) {
if (param.isInOut())
hasDifferentiableInoutParameter = true;
continue;
}
auto diagLoc =
repr ? (*repr)->getArgsTypeRepr()->getElement(i).Type->getLoc() : loc;
auto paramTypeString = paramType->getString();
Expand All @@ -637,6 +641,7 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
if (hasValidDifferentiabilityParam)
diagnostic.fixItInsert(diagLoc, "@noDerivative ");
}

// Reject the case where all parameters have '@noDerivative'.
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
Expand All @@ -651,11 +656,27 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
}
}

// Check the result
bool differentiable = isDifferentiable(result,
/*tangentVectorEqualsSelf*/ isLinear,
dc, stage);
if (!differentiable) {
// Check the result.
bool resultExists = !(result->isVoid());
bool resultIsDifferentiable = TypeChecker::isDifferentiable(
result, /*tangentVectorEqualsSelf*/ isLinear, dc, stage);
bool differentiableResultExists = resultExists && resultIsDifferentiable;

// Reject the case where there are multiple semantic results.
if (differentiableResultExists && hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc,
diag::differentiable_function_type_multiple_semantic_results);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getSourceRange());
}
}

// Reject the case where the semantic result is not differentiable.
if (!resultIsDifferentiable && !hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto resultStr = fnTy->getResult()->getString();
auto diag = ctx.Diags.diagnose(
Expand All @@ -667,6 +688,19 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}

// Reject the case where there are no semantic results.
if (!resultExists && !hasDifferentiableInoutParameter) {
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
auto diag = ctx.Diags.diagnose(
diagLoc, diag::differentiable_function_type_no_differentiable_result,
isLinear);
hadAnyError = true;

if (repr) {
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
}
}
}

return hadAnyError;
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ extension ProtocolRequirementDerivative {
func multipleSemanticResults(_ x: inout Float) -> Float {
return x
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(x: inout Float) -> (
value: Float, pullback: (Float) -> Float
Expand Down Expand Up @@ -885,14 +885,14 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}

extension InoutParameters {
func multipleSemanticResults(_ x: inout Float) -> Float { x }
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: multipleSemanticResults)
func vjpMultipleSemanticResults(_ x: inout Float) -> (
value: Float, pullback: (inout Float) -> Void
) { fatalError() }

func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@derivative(of: inoutVoid)
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
value: Float, pullback: (inout Float) -> Void
Expand Down
10 changes: 5 additions & 5 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func two9(x: Float, y: Float) -> Float {
func inout1(x: Float, y: inout Float) -> Void {
let _ = x + y
}
// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse, wrt: y)
func inout2(x: Float, y: inout Float) -> Float {
let _ = x + y
Expand Down Expand Up @@ -670,11 +670,11 @@ final class FinalClass: Differentiable {
@differentiable(reverse, wrt: y)
func inoutVoid(x: Float, y: inout Float) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
func multipleSemanticResults(_ x: inout Float) -> Float { x }

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse, wrt: y)
func swap(x: inout Float, y: inout Float) {}

Expand All @@ -687,7 +687,7 @@ extension InoutParameters {
@differentiable(reverse)
static func staticMethod(_ lhs: inout Self, rhs: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {}
}
Expand All @@ -696,7 +696,7 @@ extension InoutParameters {
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) {}

// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}}
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
@differentiable(reverse)
mutating func mutatingMethod(_ other: Self) -> Self {}
}
Expand Down
13 changes: 12 additions & 1 deletion test/AutoDiff/Sema/differentiable_func_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ let _: @differentiable(reverse) (Float) throws -> Float

struct NonDiffType { var x: Int }

// FIXME: Properly type-check parameters and the result's differentiability
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
let _: @differentiable(reverse) (NonDiffType) -> Float

Expand All @@ -29,6 +28,12 @@ let _: @differentiable(reverse) (Float, NonDiffType) -> Float
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
let _: @differentiable(_linear) (Float) -> NonDiffType

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(reverse) (inout Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(_linear) (inout Float) -> Float

// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'; did you want to add '@noDerivative' to this parameter?}} {{41-41=@noDerivative }}
let _: @differentiable(_linear) (Float, NonDiffType) -> Float
Expand All @@ -41,6 +46,12 @@ let _: @differentiable(_linear) (Float) -> NonDiffType

let _: @differentiable(_linear) (Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(reverse) (inout Float) -> Float

// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
let _: @differentiable(_linear) (inout Float) -> Float

// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test1<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Float) {}
// expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// SR-15808: In AST, type checking skips a closure with non-differentiable input
// where `Void` is included as a parameter without being marked `@noDerivative`.
// It also crashes when the output is `Void` and no input is `inout`. As a
// result, the compiler crashes during Sema.
import _Differentiation

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (()) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (()) -> Float) {}

// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
func helloWorld(_ x: @differentiable(reverse) (Float) -> Void) {}

func helloWorld(_ x: @differentiable(reverse) (@noDerivative Float, Void) -> Float) {}

// Original crash:
// Assertion failed: (!parameterIndices->isEmpty() && "Parameter indices must not be empty"), function getAutoDiffDerivativeFunctionType, file SILFunctionType.cpp, line 800.
// Stack dump:
// ...
// 1. Apple Swift version 5.6-dev (LLVM 7b20e61dd04138a, Swift 9438cf6b2e83c5f)
// 2. Compiling with the current language version
// 3. While evaluating request ASTLoweringRequest(Lowering AST to SIL for file "/Users/philipturner/Desktop/Experimentation4/Experimentation4/main.swift")
// Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
// 0 swift-frontend 0x0000000108d7a5c0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 56
// 1 swift-frontend 0x0000000108d79820 llvm::sys::RunSignalHandlers() + 128
// 2 swift-frontend 0x0000000108d7ac24 SignalHandler(int) + 304
// 3 libsystem_platform.dylib 0x00000001bb5304e4 _sigtramp + 56
// 4 libsystem_pthread.dylib 0x00000001bb518eb0 pthread_kill + 288
// 5 libsystem_c.dylib 0x00000001bb456314 abort + 164
// 6 libsystem_c.dylib 0x00000001bb45572c err + 0
// 7 swift-frontend 0x0000000108d9ae3c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) (.cold.3) + 0
// 8 swift-frontend 0x0000000104abc35c swift::SILFunctionType::getAutoDiffDerivativeFunctionType(swift::IndexSubset*, swift::IndexSubset*, swift::AutoDiffDerivativeFunctionKind, swift::Lowering::TypeConverter&, llvm::function_ref<swift::ProtocolConformanceRef (swift::CanType, swift::Type, swift::ProtocolDecl*)>, swift::CanGenericSignature, bool, swift::CanType) + 152
// 9 swift-frontend 0x0000000104b496cc (anonymous namespace)::TypeClassifierBase<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*>::getNormalDifferentiableSILFunctionTypeRecursiveProperties(swift::CanTypeWrapper<swift::SILFunctionType>, swift::Lowering::AbstractionPattern) + 184
// 10 swift-frontend 0x0000000104b3b72c swift::CanTypeVisitor<(anonymous namespace)::LowerType, swift::Lowering::TypeLowering*, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t>::visit(swift::CanType, swift::Lowering::AbstractionPattern, swift::Lowering::IsTypeExpansionSensitive_t) + 1980
// 11 swift-frontend 0x0000000104b3c0e0 swift::Lowering::TypeConverter::getTypeLoweringForLoweredType(swift::Lowering::AbstractionPattern, swift::CanType, swift::TypeExpansionContext, swift::Lowering::IsTypeExpansionSensitive_t) + 648
// 12 swift-frontend 0x0000000104b3ae08 swift::Lowering::TypeConverter::getTypeLowering(swift::Lowering::AbstractionPattern, swift::Type, swift::TypeExpansionContext) + 708
// 13 swift-frontend 0x0000000104ac8544 (anonymous namespace)::DestructureInputs::visit(swift::ValueOwnership, bool, swift::Lowering::AbstractionPattern, swift::CanType, bool, bool) + 184
// 14 swift-frontend 0x0000000104ac6a1c getSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, (anonymous namespace)::Conventions const&, swift::ForeignInfo const&, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 2584
// 15 swift-frontend 0x0000000104ac5f98 getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>)::$_12::operator()((anonymous namespace)::Conventions const&) const + 316
// 16 swift-frontend 0x0000000104abf55c getNativeSILFunctionType(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::Lowering::AbstractionPattern, swift::CanTypeWrapper<swift::AnyFunctionType>, swift::SILExtInfoBuilder, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SILDeclRef>, llvm::Optional<swift::SubstitutionMap>, swift::ProtocolConformanceRef, llvm::Optional<llvm::SmallBitVector>) + 508
// 17 swift-frontend 0x0000000104ac0b44 getUncachedSILFunctionTypeForConstant(swift::Lowering::TypeConverter&, swift::TypeExpansionContext, swift::SILDeclRef, swift::Lowering::TypeConverter::LoweredFormalTypes) + 1920
// 18 swift-frontend 0x0000000104ac1474 swift::Lowering::TypeConverter::getConstantInfo(swift::TypeExpansionContext, swift::SILDeclRef) + 216
// 19 swift-frontend 0x0000000104ab9808 swift::SILFunctionBuilder::getOrCreateFunction(swift::SILLocation, swift::SILDeclRef, swift::ForDefinition_t, llvm::function_ref<swift::SILFunction* (swift::SILLocation, swift::SILDeclRef)>, swift::ProfileCounter) + 132
// 20 swift-frontend 0x0000000104f1d120 swift::Lowering::SILGenModule::getFunction(swift::SILDeclRef, swift::ForDefinition_t) + 328
// 21 swift-frontend 0x0000000104f2086c emitOrDelayFunction(swift::Lowering::SILGenModule&, swift::SILDeclRef, bool) + 344
// 22 swift-frontend 0x0000000104f1d828 swift::Lowering::SILGenModule::emitFunction(swift::FuncDecl*) + 140
// 23 swift-frontend 0x0000000104f2294c swift::ASTLoweringRequest::evaluate(swift::Evaluator&, swift::ASTLoweringDescriptor) const + 1612
// 24 swift-frontend 0x0000000104fcbca4 swift::SimpleRequest<swift::ASTLoweringRequest, std::__1::unique_ptr<swift::SILModule, std::__1::default_delete<swift::SILModule> > (swift::ASTLoweringDescriptor), (swift::RequestFlags)9>::evaluateRequest(swift::ASTLoweringRequest const&, swift::Evaluator&) + 156
// 25 swift-frontend 0x0000000104f2647c llvm::Expected<swift::ASTLoweringRequest::OutputType> swift::Evaluator::getResultUncached<swift::ASTLoweringRequest>(swift::ASTLoweringRequest const&) + 408
// 26 swift-frontend 0x0000000104f233b0 swift::performASTLowering(swift::FileUnit&, swift::Lowering::TypeConverter&, swift::SILOptions const&) + 104
// 27 swift-frontend 0x0000000104a12088 swift::performCompileStepsPostSema(swift::CompilerInstance&, int&, swift::FrontendObserver*) + 496
// 28 swift-frontend 0x0000000104a13d08 swift::performFrontend(llvm::ArrayRef<char const*>, char const*, void*, swift::FrontendObserver*) + 2936
// 29 swift-frontend 0x00000001049b213c swift::mainEntry(int, char const**) + 500
// 30 dyld 0x00000001113c90f4 start + 520
Loading