diff --git a/lib/AST/ASTScopeCreation.cpp b/lib/AST/ASTScopeCreation.cpp index 4eee2035a67fe..66821d17e6c49 100644 --- a/lib/AST/ASTScopeCreation.cpp +++ b/lib/AST/ASTScopeCreation.cpp @@ -100,13 +100,14 @@ class ScopeCreator final : public ASTAllocated { ASTScopeAssert(expr, "If looking for closures, must have an expression to search."); - /// AST walker that finds top-level closures in an expression. - class ClosureFinder : public ASTWalker { + /// AST walker that finds nested scopes in expressions. This handles both + /// closures and if/switch expressions. + class NestedExprScopeFinder : public ASTWalker { ScopeCreator &scopeCreator; ASTScopeImpl *parent; public: - ClosureFinder(ScopeCreator &scopeCreator, ASTScopeImpl *parent) + NestedExprScopeFinder(ScopeCreator &scopeCreator, ASTScopeImpl *parent) : scopeCreator(scopeCreator), parent(parent) {} PreWalkResult walkToExprPre(Expr *E) override { @@ -122,6 +123,13 @@ class ScopeCreator final : public ASTAllocated { parent, capture); return Action::SkipChildren(E); } + + // If we have a single value statement expression, we need to add any + // scopes in the underlying statement. + if (auto *SVE = dyn_cast(E)) { + scopeCreator.addToScopeTree(SVE->getStmt(), parent); + return Action::SkipChildren(E); + } return Action::Continue(E); } PreWalkResult walkToStmtPre(Stmt *S) override { @@ -148,7 +156,7 @@ class ScopeCreator final : public ASTAllocated { } }; - expr->walk(ClosureFinder(*this, parent)); + expr->walk(NestedExprScopeFinder(*this, parent)); } public: @@ -518,11 +526,6 @@ class NodeAdder if (!expr) return p; - // If we have a single value statement expression, we expand scopes based - // on the underlying statement. - if (auto *SVE = dyn_cast(expr)) - return visit(SVE->getStmt(), p, scopeCreator); - scopeCreator.addExprToScopeTree(expr, p); return p; } diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index 730776dcb88b6..c5ae76ac144f4 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -3842,6 +3842,23 @@ class SingleValueStmtUsageChecker final : public ASTWalker { return MacroWalking::Expansion; } + AssignExpr *findAssignment(Expr *E) const { + // Don't consider assignments if we have a parent expression (as otherwise + // this would be effectively allowing it in an arbitrary expression + // position). + if (Parent.getAsExpr()) + return nullptr; + + // Look through optional exprs, which are present for e.g x?.y = z, as + // we wrap the entire assign in the optional evaluation of the destination. + if (auto *OEE = dyn_cast(E)) { + E = OEE->getSubExpr(); + while (auto *IIO = dyn_cast(E)) + E = IIO->getSubExpr(); + } + return dyn_cast(E); + } + PreWalkResult walkToExprPre(Expr *E) override { if (auto *SVE = dyn_cast(E)) { // Diagnose a SingleValueStmtExpr in a context that we do not currently @@ -3917,13 +3934,9 @@ class SingleValueStmtUsageChecker final : public ASTWalker { return Action::Continue(E); } - // Valid as the source of an assignment, as long as it's not a nested - // expression (as otherwise this would be effectively allowing it in an - // arbitrary expression position). - if (auto *AE = dyn_cast(E)) { - if (!Parent.getAsExpr()) - markValidSingleValueStmt(AE->getSrc()); - } + // Valid as the source of an assignment. + if (auto *AE = findAssignment(E)) + markValidSingleValueStmt(AE->getSrc()); // Valid as a single expression body of a closure. This is needed in // addition to ReturnStmt checking, as we will remove the return if the diff --git a/test/SILGen/if_expr.swift b/test/SILGen/if_expr.swift index 182f7c74f6154..b85152bccce87 100644 --- a/test/SILGen/if_expr.swift +++ b/test/SILGen/if_expr.swift @@ -263,3 +263,240 @@ func nestedType() throws -> Int { 0 } } + +// MARK: Bindings + +enum E { + case e(Int) +} + +struct S { + var i: Int + var opt: Int? + + var computed: Int { + get { i } + set { i = newValue } + } + var coroutined: Int { + _read { yield i } + _modify { yield &i } + } + + subscript(x: Int) -> Int { + get { i } + set { i = newValue } + } + + mutating func testAssign1(_ x: E) { + i = if case .e(let y) = x { y } else { 0 } + } + + + mutating func testAssign2(_ x: E) { + i = if case .e(let y) = x { Int(y) } else { 0 } + } + + func testAssign3(_ x: E) { + var i = 0 + i = if case .e(let y) = x { y } else { 0 } + _ = i + } + + func testAssign4(_ x: E) { + var i = 0 + let _ = { + i = if case .e(let y) = x { y } else { 0 } + } + _ = i + } + + mutating func testAssign5(_ x: E) { + i = switch Bool.random() { + case true: + if case .e(let y) = x { y } else { 0 } + case let z: + z ? 0 : 1 + } + } + + mutating func testAssign6(_ x: E) { + i = if case .e(let y) = x { + switch Bool.random() { + case true: y + case false: y + } + } else { + 0 + } + } + + mutating func testAssign7(_ x: E?) { + i = if let x = x { + switch x { + case .e(let y): y + } + } else { + 0 + } + } + + func testReturn1(_ x: E) -> Int { + if case .e(let y) = x { y } else { 0 } + } + + func testReturn2(_ x: E) -> Int { + return if case .e(let y) = x { y } else { 0 } + } + + func testReturn3(_ x: E) -> Int { + { + if case .e(let y) = x { y } else { 0 } + }() + } + + func testReturn4(_ x: E) -> Int { + return { + if case .e(let y) = x { y } else { 0 } + }() + } + + func testBinding1(_ x: E) -> Int { + let i = if case .e(let y) = x { y } else { 0 } + return i + } + + func testBinding2(_ x: E) -> Int { + let i = { + if case .e(let y) = x { y } else { 0 } + }() + return i + } +} + +enum G { + case e(Int) + case f +} + +struct TestLValues { + var s: S + var opt: S? + var optopt: S?? + + mutating func testOptPromote1() { + opt = if .random() { s } else { s } + } + + mutating func testOptPromote2() { + optopt = if .random() { s } else { s } + } + + mutating func testStored1() { + s.i = if .random() { 1 } else { 0 } + } + + mutating func testStored2() throws { + s.i = if .random() { 1 } else { throw Err() } + } + + mutating func testComputed1() { + s.computed = if .random() { 1 } else { 0 } + } + + mutating func testComputed2() throws { + s.computed = if .random() { 1 } else { throw Err() } + } + + mutating func testCoroutined1() { + s.coroutined = if .random() { 1 } else { 0 } + } + + mutating func testCoroutined2() throws { + s.coroutined = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalChain1() { + opt?.i = if .random() { 1 } else { 0 } + } + + mutating func testOptionalChain2() throws { + opt?.i = if .random() { throw Err() } else { 0 } + } + + mutating func testOptionalChain3(_ g: G) { + opt?.i = if case .e(let i) = g { i } else { 0 } + } + + mutating func testOptionalChain4(_ g: G) throws { + opt?.i = if case .e(let i) = g { i } else { throw Err() } + } + + mutating func testOptionalChain5(_ g: G) throws { + opt?.computed = if case .e(let i) = g { i } else { throw Err() } + } + + mutating func testOptionalChain6(_ g: G) throws { + opt?.coroutined = if case .e(let i) = g { i } else { throw Err() } + } + + mutating func testOptionalChain7() throws { + optopt??.i = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalChain8() throws { + optopt??.opt = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalChain9() throws { + optopt??.opt? = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalForce1() throws { + opt!.i = if .random() { throw Err() } else { 0 } + } + + mutating func testOptionalForce2() throws { + opt!.computed = if .random() { throw Err() } else { 0 } + } + + mutating func testOptionalForce3(_ g: G) throws { + opt!.coroutined = if case .e(let i) = g { i } else { throw Err() } + } + + mutating func testOptionalForce4() throws { + optopt!!.i = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalForce5() throws { + optopt!!.opt = if .random() { 1 } else { throw Err() } + } + + mutating func testOptionalForce6() throws { + optopt!!.opt! = if .random() { 1 } else { throw Err() } + } + + mutating func testSubscript1() throws { + s[5] = if .random() { 1 } else { throw Err() } + } + + mutating func testSubscript2() throws { + opt?[5] = if .random() { 1 } else { throw Err() } + } + + mutating func testSubscript3() throws { + opt![5] = if .random() { 1 } else { throw Err() } + } + + mutating func testKeyPath1(_ kp: WritableKeyPath) throws { + s[keyPath: kp] = if .random() { 1 } else { throw Err() } + } + + mutating func testKeyPath2(_ kp: WritableKeyPath) throws { + opt?[keyPath: kp] = if .random() { 1 } else { throw Err() } + } + + mutating func testKeyPath3(_ kp: WritableKeyPath) throws { + opt![keyPath: kp] = if .random() { 1 } else { throw Err() } + } +} diff --git a/test/SILGen/switch_expr.swift b/test/SILGen/switch_expr.swift index 736872382f700..7275aa1bb534d 100644 --- a/test/SILGen/switch_expr.swift +++ b/test/SILGen/switch_expr.swift @@ -358,3 +358,250 @@ func nestedType() throws -> Int { 0 } } + +// MARK: Bindings + +enum F { + case e(Int) +} + +struct S { + var i: Int + var opt: Int? + + var computed: Int { + get { i } + set { i = newValue } + } + var coroutined: Int { + _read { yield i } + _modify { yield &i } + } + + subscript(x: Int) -> Int { + get { i } + set { i = newValue } + } + + mutating func testAssign1(_ x: F) { + i = switch x { + case .e(let y): y + } + } + + mutating func testAssign2(_ x: F) { + i = switch x { + case .e(let y): Int(y) + } + } + + func testAssign3(_ x: F) { + var i = 0 + i = switch x { + case .e(let y): y + } + _ = i + } + + func testAssign4(_ x: F) { + var i = 0 + let _ = { + i = switch x { + case .e(let y): y + } + } + _ = i + } + + mutating func testAssign5(_ x: F) { + i = switch Bool.random() { + case true: + switch x { + case .e(let y): y + } + case let z: + z ? 0 : 1 + } + } + + mutating func testAssign6(_ x: F) { + i = switch x { + case .e(let y): + switch Bool.random() { + case true: y + case false: y + } + } + } + + func testReturn1(_ x: F) -> Int { + switch x { + case .e(let y): y + } + } + + func testReturn2(_ x: F) -> Int { + return switch x { + case .e(let y): y + } + } + + func testReturn3(_ x: F) -> Int { + { + switch x { + case .e(let y): y + } + }() + } + + func testReturn4(_ x: F) -> Int { + return { + switch x { + case .e(let y): y + } + }() + } + + func testBinding1(_ x: F) -> Int { + let i = switch x { + case .e(let y): y + } + return i + } + + func testBinding2(_ x: F) -> Int { + let i = { + switch x { + case .e(let y): y + } + }() + return i + } +} + +enum G { + case e(Int) + case f +} + +struct TestLValues { + var s: S + var opt: S? + var optopt: S?? + + mutating func testOptPromote1() { + opt = switch Bool.random() { case true: s case false: s } + } + + mutating func testOptPromote2() { + optopt = switch Bool.random() { case true: s case false: s } + } + + mutating func testStored1() { + s.i = switch Bool.random() { case true: 1 case false: 0 } + } + + mutating func testStored2() throws { + s.i = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testComputed1() { + s.computed = switch Bool.random() { case true: 1 case false: 0 } + } + + mutating func testComputed2() throws { + s.computed = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testCoroutined1() { + s.coroutined = switch Bool.random() { case true: 1 case false: 0 } + } + + mutating func testCoroutined2() throws { + s.coroutined = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalChain1() { + opt?.i = switch Bool.random() { case true: 1 case false: 0 } + } + + mutating func testOptionalChain2() throws { + opt?.i = switch Bool.random() { case true: throw Err() case false: 0 } + } + + mutating func testOptionalChain3(_ g: G) { + opt?.i = switch g { case .e(let i): i default: 0 } + } + + mutating func testOptionalChain4(_ g: G) throws { + opt?.i = switch g { case .e(let i): i default: throw Err() } + } + + mutating func testOptionalChain5(_ g: G) throws { + opt?.computed = switch g { case .e(let i): i default: throw Err() } + } + + mutating func testOptionalChain6(_ g: G) throws { + opt?.coroutined = switch g { case .e(let i): i default: throw Err() } + } + + mutating func testOptionalChain7() throws { + optopt??.i = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalChain8() throws { + optopt??.opt = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalChain9() throws { + optopt??.opt? = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalForce1() throws { + opt!.i = switch Bool.random() { case true: throw Err() case false: 0 } + } + + mutating func testOptionalForce2() throws { + opt!.computed = switch Bool.random() { case true: throw Err() case false: 0 } + } + + mutating func testOptionalForce3(_ g: G) throws { + opt!.coroutined = switch g { case .e(let i): i default: throw Err() } + } + + mutating func testOptionalForce4() throws { + optopt!!.i = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalForce5() throws { + optopt!!.opt = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testOptionalForce6() throws { + optopt!!.opt! = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testSubscript1() throws { + s[5] = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testSubscript2() throws { + opt?[5] = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testSubscript3() throws { + opt![5] = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testKeyPath1(_ kp: WritableKeyPath) throws { + s[keyPath: kp] = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testKeyPath2(_ kp: WritableKeyPath) throws { + opt?[keyPath: kp] = switch Bool.random() { case true: 1 case false: throw Err() } + } + + mutating func testKeyPath3(_ kp: WritableKeyPath) throws { + opt![keyPath: kp] = switch Bool.random() { case true: 1 case false: throw Err() } + } +}