Skip to content

Fix #2578 Part 1: Tighten type checking of pattern bindings #6389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 7, 2019
Merged
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ object desugar {
*/
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key

/** An attachment for match expressions generated from a PatDef */
val PatDefMatch: Property.Key[Unit] = new Property.Key

/** Info of a variable in a pattern: The named tree and its type */
private type VarInfo = (NameTree, Tree)

Expand Down Expand Up @@ -956,7 +959,11 @@ object desugar {
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)

def rhsUnchecked = makeAnnotated("scala.unchecked", rhs)
def rhsUnchecked = {
val rhs1 = makeAnnotated("scala.unchecked", rhs)
rhs1.pushAttachment(PatDefMatch, ())
rhs1
}
val vars =
if (tupleOptimizable) // include `_`
pat match {
Expand Down
14 changes: 9 additions & 5 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1528,13 +1528,17 @@ object Types {
*/
def signature(implicit ctx: Context): Signature = Signature.NotAMethod

def dropRepeatedAnnot(implicit ctx: Context): Type = this match {
case AnnotatedType(parent, annot) if annot.symbol eq defn.RepeatedAnnot => parent
case tp @ AnnotatedType(parent, annot) =>
tp.derivedAnnotatedType(parent.dropRepeatedAnnot, annot)
case tp => tp
/** Drop annotation of given `cls` from this type */
def dropAnnot(cls: Symbol)(implicit ctx: Context): Type = stripTypeVar match {
case self @ AnnotatedType(pre, annot) =>
if (annot.symbol eq cls) pre
else self.derivedAnnotatedType(pre.dropAnnot(cls), annot)
case _ =>
this
}

def dropRepeatedAnnot(implicit ctx: Context): Type = dropAnnot(defn.RepeatedAnnot)

def annotatedToRepeated(implicit ctx: Context): Type = this match {
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.annotatedToRepeated)
case AnnotatedType(tp, annot) if annot matches defn.RepeatedAnnot =>
Expand Down
59 changes: 41 additions & 18 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1200,14 +1200,15 @@ object Parsers {
* | ForExpr
* | [SimpleExpr `.'] id `=' Expr
* | SimpleExpr1 ArgumentExprs `=' Expr
* | PostfixExpr [Ascription]
* | [‘inline’] PostfixExpr `match' `{' CaseClauses `}'
* | Expr2
* | [‘inline’] Expr2 `match' `{' CaseClauses `}'
* | `implicit' `match' `{' ImplicitCaseClauses `}'
* Bindings ::= `(' [Binding {`,' Binding}] `)'
* Binding ::= (id | `_') [`:' Type]
* Ascription ::= `:' CompoundType
* | `:' Annotation {Annotation}
* | `:' `_' `*'
* Bindings ::= `(' [Binding {`,' Binding}] `)'
* Binding ::= (id | `_') [`:' Type]
* Expr2 ::= PostfixExpr [Ascription]
* Ascription ::= `:' InfixType
* | `:' Annotation {Annotation}
* | `:' `_' `*'
*/
val exprInParens: () => Tree = () => expr(Location.InParens)

Expand Down Expand Up @@ -1324,15 +1325,16 @@ object Parsers {
t
}
case COLON =>
ascription(t, location)
in.nextToken()
val t1 = ascription(t, location)
if (in.token == MATCH) expr1Rest(t1, location) else t1
case MATCH =>
matchExpr(t, startOffset(t), Match)
case _ =>
t
}

def ascription(t: Tree, location: Location.Value): Tree = atSpan(startOffset(t)) {
in.skipToken()
in.token match {
case USCORE =>
val uscoreStart = in.skipToken()
Expand Down Expand Up @@ -1801,7 +1803,10 @@ object Parsers {
*/
def pattern1(): Tree = {
val p = pattern2()
if (isVarPattern(p) && in.token == COLON) ascription(p, Location.InPattern)
if (isVarPattern(p) && in.token == COLON) {
in.nextToken()
ascription(p, Location.InPattern)
}
else p
}

Expand Down Expand Up @@ -2353,14 +2358,32 @@ object Parsers {
tmplDef(start, mods)
}

/** PatDef ::= Pattern2 {`,' Pattern2} [`:' Type] `=' Expr
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
* ValDcl ::= id {`,' id} `:' Type
* VarDcl ::= id {`,' id} `:' Type
/** PatDef ::= ids [‘:’ Type] ‘=’ Expr
* | Pattern2 [‘:’ Type | Ascription] ‘=’ Expr
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
* ValDcl ::= id {`,' id} `:' Type
* VarDcl ::= id {`,' id} `:' Type
*/
def patDefOrDcl(start: Offset, mods: Modifiers): Tree = atSpan(start, nameStart) {
val lhs = commaSeparated(pattern2)
val tpt = typedOpt()
val first = pattern2()
var lhs = first match {
case id: Ident if in.token == COMMA =>
in.nextToken()
id :: commaSeparated(() => termIdent())
case _ =>
first :: Nil
}
def emptyType = TypeTree().withSpan(Span(in.lastOffset))
val tpt =
if (in.token == COLON) {
in.nextToken()
if (in.token == AT && lhs.tail.isEmpty) {
lhs = ascription(first, Location.ElseWhere) :: Nil
emptyType
}
else toplevelTyp()
}
else emptyType
val rhs =
if (tpt.isEmpty || in.token == EQUALS) {
accept(EQUALS)
Expand All @@ -2374,9 +2397,9 @@ object Parsers {
lhs match {
case (id: BackquotedIdent) :: Nil if id.name.isTermName =>
finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start)
case Ident(name: TermName) :: Nil => {
case Ident(name: TermName) :: Nil =>
finalizeDef(ValDef(name, tpt, rhs), mods, start)
} case _ =>
case _ =>
PatDef(mods, lhs, tpt, rhs)
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
if (selType <:< unapplyArgType) {
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
selType
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
Expand Down
47 changes: 46 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ import ProtoTypes._
import Scopes._
import CheckRealizable._
import ErrorReporting.errorTree
import rewrites.Rewrites.patch
import util.Spans.Span

import util.SourcePosition
import transform.SymUtils._
import Decorators._
import ErrorReporting.{err, errorType}
import config.Printers.typr
import config.Printers.{typr, patmatch}
import NameKinds.DefaultGetterName
import Applications.unapplyArgs

import collection.mutable
import SymDenotations.{NoCompleter, NoDenotation}
Expand Down Expand Up @@ -594,6 +597,48 @@ trait Checking {
ctx.error(ex"$cls cannot be instantiated since it${rstatus.msg}", pos)
}

/** Check that pattern `pat` is irrefutable for scrutinee tye `pt`.
* This means `pat` is either marked @unchecked or `pt` conforms to the
* pattern's type. If pattern is an UnApply, do the check recursively.
*/
def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = {
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")

def check(pat: Tree, pt: Type): Boolean = {
if (pt <:< pat.tpe)
true
else {
ctx.errorOrMigrationWarning(
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)}
|
|If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""",
pat.sourcePos)
false
}
}

!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
pat match {
case Bind(_, pat1) =>
checkIrrefutable(pat1, pt)
case UnApply(fn, _, pats) =>
check(pat, pt) && {
val argPts = unapplyArgs(fn.tpe.finalResultType, fn, pats, pat.sourcePos)
pats.corresponds(argPts)(checkIrrefutable)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For UnApply, we should also check if the extractor itself is irrefutable, i.e. if it returns Some, or true, or a product. The logic to do that is in SpaceEngine#irrefutable.

Additionally, for some reason unapplies are simply rejected as val definitions:

scala> {
     |   object Positive { def unapply(i: Int): Option[Int] = Some(i).filter(_ > 0) }
     |   val Positive(p) = 5
     |   5 match { case Positive(p) => p }
     | }
3 |  val Positive(p) = 5
  |      ^^^^^^^^^^^
  | ((i: Int): Option[Int])(Positive.unapply) is not a valid result type of an unapply method of an extractor.

case Alternative(pats) =>
pats.forall(checkIrrefutable(_, pt))
case Typed(arg, tpt) =>
check(pat, pt) && checkIrrefutable(arg, pt)
case Ident(nme.WILDCARD) =>
true
case _ =>
check(pat, pt)
}
}
}

/** Check that `path` is a legal prefix for an import or export clause */
def checkLegalImportPath(path: Tree)(implicit ctx: Context): Unit = {
checkStable(path.tpe, path.sourcePos)
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ object ErrorReporting {
}
"""\$\{\w*\}""".r.replaceSomeIn(raw, m => translate(m.matched.drop(2).init))
}

def rewriteNotice: String =
if (ctx.scala2Mode) "\nThis patch can be inserted automatically under -rewrite."
else ""
}

def err(implicit ctx: Context): Errors = new Errors
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,15 @@ class Typer extends Namer
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
val sel1 = typedExpr(tree.selector)
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
result match {
case Match(sel, CaseDef(pat, _, _) :: _)
if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) =>
if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode)
patch(Span(pat.span.end), ": @unchecked")
case _ =>
}
result
}
}

Expand Down Expand Up @@ -1817,8 +1825,11 @@ class Typer extends Namer
}
case _ => arg1
}
val tpt = TypeTree(AnnotatedType(arg1.tpe.widenIfUnstable, Annotation(annot1)))
assignType(cpy.Typed(tree)(arg2, tpt), tpt)
val argType =
if (arg1.isInstanceOf[Bind]) arg1.tpe.widen // bound symbol is not accessible outside of Bind node
else arg1.tpe.widenIfUnstable
val annotatedTpt = TypeTree(AnnotatedType(argType, Annotation(annot1)))
assignType(cpy.Typed(tree)(arg2, annotatedTpt), annotatedTpt)
}
}

Expand Down
3 changes: 0 additions & 3 deletions compiler/test-resources/repl/patdef
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@ scala> val _ @ List(x) = List(1)
val x: Int = 1
scala> val List(_ @ List(x)) = List(List(2))
val x: Int = 2
scala> val B @ List(), C: List[Int] = List()
val B: List[Int] = List()
val C: List[Int] = List()
7 changes: 4 additions & 3 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class CompilationTests extends ParallelTesting {
aggregateTests(
compileFilesInDir("tests/neg", defaultOptions),
compileFilesInDir("tests/neg-tailcall", defaultOptions),
compileFilesInDir("tests/neg-strict", defaultOptions.and("-strict")),
compileFilesInDir("tests/neg-no-kind-polymorphism", defaultOptions and "-Yno-kind-polymorphism"),
compileFilesInDir("tests/neg-custom-args/deprecation", defaultOptions.and("-Xfatal-warnings", "-deprecation")),
compileFilesInDir("tests/neg-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")),
Expand All @@ -160,8 +161,6 @@ class CompilationTests extends ParallelTesting {
compileFile("tests/neg-custom-args/i3246.scala", scala2Mode),
compileFile("tests/neg-custom-args/overrideClass.scala", scala2Mode),
compileFile("tests/neg-custom-args/autoTuplingTest.scala", defaultOptions.and("-language:noAutoTupling")),
compileFile("tests/neg-custom-args/i1050.scala", defaultOptions.and("-strict")),
compileFile("tests/neg-custom-args/nullless.scala", defaultOptions.and("-strict")),
compileFile("tests/neg-custom-args/nopredef.scala", defaultOptions.and("-Yno-predef")),
compileFile("tests/neg-custom-args/noimports.scala", defaultOptions.and("-Yno-imports")),
compileFile("tests/neg-custom-args/noimports2.scala", defaultOptions.and("-Yno-imports")),
Expand Down Expand Up @@ -249,7 +248,9 @@ class CompilationTests extends ParallelTesting {

val lib =
compileList("src", librarySources,
defaultOptions.and("-Ycheck-reentrant", "-strict", "-priorityclasspath", defaultOutputDir))(libGroup)
defaultOptions.and("-Ycheck-reentrant",
// "-strict", // TODO: re-enable once we allow : @unchecked in pattern definitions. Right now, lots of narrowing pattern definitions fail.
"-priorityclasspath", defaultOutputDir))(libGroup)

val compilerSources = sources(Paths.get("compiler/src"))
val compilerManagedSources = sources(Properties.dottyCompilerManagedSources)
Expand Down
10 changes: 6 additions & 4 deletions docs/docs/internals/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@ Expr1 ::= ‘if’ ‘(’ Expr ‘)’ {nl}
| ForExpr
| [SimpleExpr ‘.’] id ‘=’ Expr Assign(expr, expr)
| SimpleExpr1 ArgumentExprs ‘=’ Expr Assign(expr, expr)
| PostfixExpr [Ascription]
| [‘inline’] PostfixExpr ‘match’ ‘{’ CaseClauses ‘}’ Match(expr, cases) -- point on match
| Expr2
| [‘inline’] Expr2 ‘match’ ‘{’ CaseClauses ‘}’ Match(expr, cases) -- point on match
| ‘implicit’ ‘match’ ‘{’ ImplicitCaseClauses ‘}’
Expr2 ::= PostfixExpr [Ascription]
Ascription ::= ‘:’ InfixType Typed(expr, tp)
| ‘:’ Annotation {Annotation} Typed(expr, Annotated(EmptyTree, annot)*)
Catches ::= ‘catch’ Expr
Expand All @@ -224,7 +225,7 @@ SimpleExpr1 ::= Literal
Quoted ::= ‘'’ ‘{’ Block ‘}’
| ‘'’ ‘[’ Type ‘]’
ExprsInParens ::= ExprInParens {‘,’ ExprInParens}
ExprInParens ::= PostfixExpr ‘:’ Type
ExprInParens ::= PostfixExpr ‘:’ Type -- normal Expr allows only RefinedType here
| Expr
ParArgumentExprs ::= ‘(’ ExprsInParens ‘)’ exprs
| ‘(’ [ExprsInParens ‘,’] PostfixExpr ‘:’ ‘_’ ‘*’ ‘)’ exprs :+ Typed(expr, Ident(wildcardStar))
Expand Down Expand Up @@ -358,7 +359,8 @@ Def ::= ‘val’ PatDef
| ‘type’ {nl} TypeDcl
| TmplDef
| INT
PatDef ::= Pattern2 {‘,’ Pattern2} [‘:’ Type] ‘=’ Expr PatDef(_, pats, tpe?, expr)
PatDef ::= ids [‘:’ Type] ‘=’ Expr
| Pattern2 [‘:’ Type | Ascription] ‘=’ Expr PatDef(_, pats, tpe?, expr)
VarDef ::= PatDef
| ids ‘:’ Type ‘=’ ‘_’
DefDef ::= DefSig [(‘:’ | ‘<:’) Type] ‘=’ Expr DefDef(_, name, tparams, vparamss, tpe, expr)
Expand Down
File renamed without changes.
File renamed without changes.
10 changes: 10 additions & 0 deletions tests/neg-strict/unchecked-patterns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Test {
val (y1: Some[Int] @unchecked) = Some(1): Option[Int] // OK
val y2: Some[Int] @unchecked = Some(1): Option[Int] // error

val x :: xs = List(1, 2, 3) // error
val (1, c) = (1, 2) // error
val 1 *: cs = 1 *: () // error

val (_: Int | _: Any) = ??? : Any // error
}
4 changes: 4 additions & 0 deletions tests/neg/multi-patterns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
object Test {
val (a :: as), bs = List(1, 2, 3) // error
val B @ List(), C: List[Int] = List() // error
}
5 changes: 5 additions & 0 deletions tests/pos-special/fatal-warnings/unchecked-scrutinee.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object Test {
List(1: @unchecked, 2, 3): @unchecked match {
case a :: as =>
}
}
2 changes: 0 additions & 2 deletions tests/pos/i3412.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
class Test {
val A @ List() = List()
val B @ List(), C: List[Int] = List()
val D @ List(), E @ List() = List()
}
9 changes: 9 additions & 0 deletions tests/run/unchecked-patterns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object Test extends App {
val x: Int @unchecked = 2
val (y1: Some[Int] @unchecked) = Some(1): Option[Int]

val a :: as: @unchecked = List(1, 2, 3)
val lst @ b :: bs: @unchecked = List(1, 2, 3)
val (1, c): @unchecked = (1, 2)
val 1 *: cs: @unchecked = 1 *: () // error
}