Skip to content
Merged
113 changes: 69 additions & 44 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,19 @@ object desugar {
*/
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key

/** An attachment for match expressions generated from a PatDef */
val PatDefMatch: Property.Key[Unit] = new Property.Key
/** An attachment for match expressions generated from a PatDef or GenFrom.
* Value of key == one of IrrefutablePatDef, IrrefutableGenFrom
*/
val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey

/** What static check should be applied to a Match (none, irrefutable, exhaustive) */
class MatchCheck(val n: Int) extends AnyVal
object MatchCheck {
val None = new MatchCheck(0)
val Exhaustive = new MatchCheck(1)
val IrrefutablePatDef = new MatchCheck(2)
val IrrefutableGenFrom = new MatchCheck(3)
}

/** Info of a variable in a pattern: The named tree and its type */
private type VarInfo = (NameTree, Tree)
Expand Down Expand Up @@ -926,6 +937,22 @@ object desugar {
}
}

/** The selector of a match, which depends of the given `checkMode`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that @unchecked is also added to irrefutable patterns really confused me until I understood that annotating the selector only affects exhaustivity checks and not irrefutability checks. Perhaps we could note this in the doc?

Suggested change
/** The selector of a match, which depends of the given `checkMode`.
/** The selector of a match, which depends of the given `checkMode`.
*
* The @unchecked annotation is added whenever `checkMode` is not `Exhaustive` to silence
* unnecessary inexhaustive match warnings.

* @param sel the original selector
* @return if `checkMode` is
* - None : sel @unchecked
* - Exhaustive : sel
* - IrrefutablePatDef,
* IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode`
*/
def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree =
if (checkMode == MatchCheck.Exhaustive) sel
else {
val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType)))
if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode)
sel1
}

/** If `pat` is a variable pattern,
*
* val/var/lazy val p = e
Expand Down Expand Up @@ -960,11 +987,6 @@ object desugar {
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)

def rhsUnchecked = {
val rhs1 = makeAnnotated("scala.unchecked", rhs)
rhs1.pushAttachment(PatDefMatch, ())
rhs1
}
val vars =
if (tupleOptimizable) // include `_`
pat match {
Expand All @@ -977,7 +999,7 @@ object desugar {
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
val matchExpr =
if (tupleOptimizable) rhs
else Match(rhsUnchecked, caseDef :: Nil)
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
vars match {
case Nil =>
matchExpr
Expand Down Expand Up @@ -1120,20 +1142,16 @@ object desugar {
*
* { cases }
* ==>
* x$1 => (x$1 @unchecked) match { cases }
* x$1 => (x$1 @unchecked?) match { cases }
*
* If `nparams` != 1, expand instead to
*
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked) match { cases }
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases }
*/
def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = {
def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = {
val params = (1 to nparams).toList.map(makeSyntheticParameter(_))
val selector = makeTuple(params.map(p => Ident(p.name)))

if (unchecked)
Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases))
else
Function(params, Match(selector, cases))
Function(params, Match(makeSelector(selector, checkMode), cases))
}

/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
Expand Down Expand Up @@ -1262,15 +1280,19 @@ object desugar {
*/
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {

/** Make a function value pat => body.
* If pat is a var pattern id: T then this gives (id: T) => body
* Otherwise this gives { case pat => body }
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
*/
def makeLambda(pat: Tree, body: Tree): Tree = pat match {
case IdPattern(named, tpt) =>
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
case _ =>
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil)
val matchCheckMode =
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
else MatchCheck.None
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
}

/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
Expand Down Expand Up @@ -1316,7 +1338,7 @@ object desugar {
val cases = List(
CaseDef(pat, EmptyTree, Literal(Constant(true))),
CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases))
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None))
}

/** Is pattern `pat` irrefutable when matched against `rhs`?
Expand All @@ -1342,41 +1364,47 @@ object desugar {
}
}

def isIrrefutableGenFrom(gen: GenFrom): Boolean =
gen.isInstanceOf[IrrefutableGenFrom] ||
IdPattern.unapply(gen.pat).isDefined ||
isIrrefutable(gen.pat, gen.expr)
def needsNoFilter(gen: GenFrom): Boolean =
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
false
else (
gen.checkMode != GenCheckMode.FilterNow ||
IdPattern.unapply(gen.pat).isDefined ||
isIrrefutable(gen.pat, gen.expr)
)

/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
* matched against `rhs`.
*/
def rhsSelect(gen: GenFrom, name: TermName) = {
val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
Select(rhs, name)
}

def checkMode(gen: GenFrom) =
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
else MatchCheck.None // refutable paterns were already eliminated in filter step

enums match {
case (gen: GenFrom) :: Nil =>
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _) :: _)) =>
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
case (GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) =>
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(pat)
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = pat :: pats
val vfrom1 = new IrrefutableGenFrom(makeTuple(allpats), rhs1)
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test))
val genFrom =
if (isIrrefutableGenFrom(gen)) new IrrefutableGenFrom(gen.pat, filtered)
else GenFrom(gen.pat, filtered)
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case _ =>
EmptyTree //may happen for erroneous input
Expand Down Expand Up @@ -1571,7 +1599,4 @@ object desugar {
collect(tree)
buf.toList
}

private class IrrefutableGenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile)
extends GenFrom(pat, expr)
}
23 changes: 16 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class GenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
Expand All @@ -116,6 +116,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
* `Positioned#checkPos` */
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)

/** An enum to control checking or filtering of patterns in GenFrom trees */
class GenCheckMode(val x: Int) extends AnyVal
object GenCheckMode {
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
val Check = new GenCheckMode(1) // check that pattern is irrefutable
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
}

// ----- Modifiers -----------------------------------------------------
/** Mod is intended to record syntactic information about modifiers, it's
* NOT a replacement of FlagSet.
Expand Down Expand Up @@ -525,9 +534,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source))
}
def GenFrom(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.GenFrom(pat, expr)(tree.source))
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match {
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree
case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source))
}
def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree
Expand Down Expand Up @@ -589,8 +598,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
cpy.ForYield(tree)(transform(enums), transform(expr))
case ForDo(enums, body) =>
cpy.ForDo(tree)(transform(enums), transform(body))
case GenFrom(pat, expr) =>
cpy.GenFrom(tree)(transform(pat), transform(expr))
case GenFrom(pat, expr, checkMode) =>
cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode)
case GenAlias(pat, expr) =>
cpy.GenAlias(tree)(transform(pat), transform(expr))
case ContextBounds(bounds, cxBounds) =>
Expand Down Expand Up @@ -644,7 +653,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this(this(x, enums), expr)
case ForDo(enums, body) =>
this(this(x, enums), body)
case GenFrom(pat, expr) =>
case GenFrom(pat, expr, _) =>
this(this(x, pat), expr)
case GenAlias(pat, expr) =>
this(this(x, pat), expr)
Expand Down
48 changes: 29 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1725,18 +1725,28 @@ object Parsers {
*/
def enumerator(): Tree =
if (in.token == IF) guard()
else if (in.token == CASE) generator()
else {
val pat = pattern1()
if (in.token == EQUALS) atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, expr()) }
else generatorRest(pat)
else generatorRest(pat, casePat = false)
}

/** Generator ::= Pattern `<-' Expr
/** Generator ::= [‘case’] Pattern `<-' Expr
*/
def generator(): Tree = generatorRest(pattern1())
def generator(): Tree = {
val casePat = if (in.token == CASE) { in.skipCASE(); true } else false
generatorRest(pattern1(), casePat)
}

def generatorRest(pat: Tree): GenFrom =
atSpan(startOffset(pat), accept(LARROW)) { GenFrom(pat, expr()) }
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
atSpan(startOffset(pat), accept(LARROW)) {
val checkMode =
if (casePat) GenCheckMode.FilterAlways
else if (ctx.settings.strict.value) GenCheckMode.Check
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
GenFrom(pat, expr(), checkMode)
}

/** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}')
* {nl} [`yield'] Expr
Expand All @@ -1749,16 +1759,20 @@ object Parsers {
else if (in.token == LPAREN) {
val lparenOffset = in.skipToken()
openParens.change(LPAREN, 1)
val pats = patternsOpt()
val pat =
if (in.token == RPAREN || pats.length > 1) {
wrappedEnums = false
accept(RPAREN)
openParens.change(LPAREN, -1)
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
val res =
if (in.token == CASE) enumerators()
else {
val pats = patternsOpt()
val pat =
if (in.token == RPAREN || pats.length > 1) {
wrappedEnums = false
accept(RPAREN)
openParens.change(LPAREN, -1)
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
}
else pats.head
generatorRest(pat, casePat = false) :: enumeratorsRest()
}
else pats.head
val res = generatorRest(pat) :: enumeratorsRest()
if (wrappedEnums) {
accept(RPAREN)
openParens.change(LPAREN, -1)
Expand Down Expand Up @@ -2640,11 +2654,7 @@ object Parsers {
*/
def enumCase(start: Offset, mods: Modifiers): DefTree = {
val mods1 = addMod(mods, atSpan(in.offset)(Mod.Enum())) | Case
accept(CASE)

in.adjustSepRegions(ARROW)
// Scanner thinks it is in a pattern match after seeing the `case`.
// We need to get it out of that mode by telling it we are past the `=>`
in.skipCASE()

atSpan(start, nameStart) {
val id = termIdent()
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@ object Scanners {
case _ =>
}

/** Advance beyond a case token without marking the CASE in sepRegions.
* This method should be called to skip beyond CASE tokens that are
* not part of matches, i.e. no ARROW is expected after them.
*/
def skipCASE() = {
assert(token == CASE)
nextToken()
sepRegions = sepRegions.tail
}

/** Produce next token, filling TokenData fields of Scanner.
*/
def nextToken(): Unit = {
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
forText(enums, expr, keywordStr(" yield "))
case ForDo(enums, expr) =>
forText(enums, expr, keywordStr(" do "))
case GenFrom(pat, expr) =>
case GenFrom(pat, expr, checkMode) =>
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
toText(pat) ~ " <- " ~ toText(expr)
case GenAlias(pat, expr) =>
toText(pat) ~ " = " ~ toText(expr)
Expand Down
Loading