From 23300d08847cc6a5f684c02b53db2ac48db384cf Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 10 Jul 2024 17:13:40 +0200 Subject: [PATCH 1/6] Crude implementation of removing trailing unit-literal maps from for-comprehensions --- compiler/src/dotty/tools/dotc/Compiler.scala | 5 +- .../src/dotty/tools/dotc/ast/Desugar.scala | 20 +++++-- .../dotc/transform/localopt/DropForMap.scala | 56 +++++++++++++++++++ tests/run/map-unit-elim.check | 1 + tests/run/map-unit-elim.scala | 34 +++++++++++ 5 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala create mode 100644 tests/run/map-unit-elim.check create mode 100644 tests/run/map-unit-elim.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 8f22a761e790..6ef49c786e44 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -8,7 +8,7 @@ import parsing.Parser import Phases.Phase import transform.* import backend.jvm.{CollectSuperCalls, GenBCode} -import localopt.StringInterpolatorOpt +import localopt.{StringInterpolatorOpt, DropForMap} /** The central class of the dotc compiler. The job of a compiler is to create * runs, which process given `phases` in a given `rootContext`. @@ -90,7 +90,8 @@ class Compiler { new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats - new DropBreaks) :: // Optimize local Break throws by rewriting them + new DropBreaks, // Optimize local Break throws by rewriting them + new DropForMap) :: // Drop unused trailing map calls in for comprehensions List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_` new InlinePatterns, // Remove placeholders of inlined patterns diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index ec65224ac93d..b3f41f5c29d9 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -64,6 +64,11 @@ object desugar { */ val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + /** An attachment key to indicate that an Apply is created as a last `map` + * scall in a for-comprehension. + */ + val TrailingForMap: Property.Key[Unit] = Property.StickyKey() + /** What static check should be applied to a Match? */ enum MatchCheck { case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom @@ -2149,11 +2154,13 @@ object desugar { enums match { case Nil if betterForsEnabled => body case (gen: GenFrom) :: Nil => + val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) if betterForsEnabled - && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) - then gen.expr // avoid a redundant map with identity - else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + // && deepEquals(gen.pat, body) + then + aply.putAttachment(TrailingForMap, ()) + aply case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) @@ -2164,7 +2171,10 @@ object desugar { val selectName = if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName else mapName - Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + if selectName == mapName then + aply.pushAttachment(TrailingForMap, ()) + aply case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) val pats = valeqs map { case GenAlias(pat, _) => pat } diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala new file mode 100644 index 000000000000..727f2b472cdd --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -0,0 +1,56 @@ +package dotty.tools.dotc +package transform.localopt + +import scala.language.unsafeNulls + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Decorators.* +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.transform.MegaPhase.MiniPhase +import dotty.tools.dotc.typer.ConstFold +import dotty.tools.dotc.ast.desugar +import scala.util.chaining.* + +class DropForMap extends MiniPhase: + import tpd.* + + override def phaseName: String = DropForMap.name + + override def description: String = DropForMap.description + + override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = + if !tree.hasAttachment(desugar.TrailingForMap) then tree.tap(_.removeAttachment(desugar.TrailingForMap)) + else tree match + case Apply(MapCall(f), List(Lambda(List(param), body))) + if isEssentiallyUnitLiteral(param, body) && param.tpt.tpe.isRef(defn.UnitClass) => + f + case _ => + tree.tap(_.removeAttachment(desugar.TrailingForMap)) + + private object Lambda: + def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = + tree match + case Block(List(defdef: DefDef), Closure(Nil, ref, _)) if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => + Some((defdef.termParamss.flatten, defdef.rhs)) + case _ => None + + private object MapCall: + def unapply(tree: Tree)(using Context): Option[Tree] = tree match + case Select(f, nme.map) => Some(f) + case Apply(fn, _) => unapply(fn) + case TypeApply(fn, _) => unapply(fn) + case _ => None + + def isEssentiallyUnitLiteral(param: ValDef, tree: Tree)(using Context): Boolean = tree match + case Literal(Constant(())) => true + case Match(scrutinee, List(CaseDef(_, EmptyTree, body))) => isEssentiallyUnitLiteral(param, body) + case Block(Nil, expr) => isEssentiallyUnitLiteral(param, expr) + case _ => false + +object DropForMap: + val name: String = "dropForMap" + val description: String = "Drop unused trailing map calls in for comprehensions" diff --git a/tests/run/map-unit-elim.check b/tests/run/map-unit-elim.check new file mode 100644 index 000000000000..1b7b66a37884 --- /dev/null +++ b/tests/run/map-unit-elim.check @@ -0,0 +1 @@ +MySome(()) diff --git a/tests/run/map-unit-elim.scala b/tests/run/map-unit-elim.scala new file mode 100644 index 000000000000..cdefc7a8e839 --- /dev/null +++ b/tests/run/map-unit-elim.scala @@ -0,0 +1,34 @@ +import scala.language.experimental.betterFors + +class myOptionPackage(doOnMap: => Unit) { + sealed trait MyOption[+A] { + def map[B](f: A => B): MyOption[B] = this match { + case MySome(x) => { + doOnMap + MySome(f(x)) + } + case MyNone => MyNone + } + def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match { + case MySome(x) => f(x) + case MyNone => MyNone + } + } + case class MySome[A](x: A) extends MyOption[A] + case object MyNone extends MyOption[Nothing] +} + +object Test extends App { + + val myOption = new myOptionPackage(println("map called")) + + import myOption.* + + val z = for { + a <- MySome(1) + b <- MySome(()) + } yield () + + println(z) + +} From 961a3cda45f17b7fd17a709cdd1664e8b59d2940 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 17 Feb 2025 16:57:54 +0100 Subject: [PATCH 2/6] Working better-fors fix --- compiler/src/dotty/tools/dotc/Compiler.scala | 6 +- .../src/dotty/tools/dotc/ast/Desugar.scala | 8 +-- .../dotc/transform/localopt/DropForMap.scala | 67 ++++++++++++++++--- tests/pos/better-fors-i21804.scala | 13 ++++ tests/run/better-fors-map-elim.check | 4 ++ ...-elim.scala => better-fors-map-elim.scala} | 32 +++++++-- tests/run/map-unit-elim.check | 1 - 7 files changed, 106 insertions(+), 25 deletions(-) create mode 100644 tests/pos/better-fors-i21804.scala create mode 100644 tests/run/better-fors-map-elim.check rename tests/run/{map-unit-elim.scala => better-fors-map-elim.scala} (52%) delete mode 100644 tests/run/map-unit-elim.check diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 6ef49c786e44..6aab7d54d59e 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -68,7 +68,8 @@ class Compiler { new InlineVals, // Check right hand-sides of an `inline val`s new ExpandSAMs, // Expand single abstract method closures to anonymous classes new ElimRepeated, // Rewrite vararg parameters and arguments - new RefChecks) :: // Various checks mostly related to abstract members and overriding + new RefChecks, // Various checks mostly related to abstract members and overriding + new DropForMap) :: // Drop unused trailing map calls in for comprehensions List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file List(new init.Checker) :: // Check initialization of objects List(new ProtectedAccessors, // Add accessors for protected members @@ -90,8 +91,7 @@ class Compiler { new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats - new DropBreaks, // Optimize local Break throws by rewriting them - new DropForMap) :: // Drop unused trailing map calls in for comprehensions + new DropBreaks) :: // Optimize local Break throws by rewriting them List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_` new InlinePatterns, // Remove placeholders of inlined patterns diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index b3f41f5c29d9..d591341dcae6 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1971,14 +1971,8 @@ object desugar { * * 3. * - * for (P <- G) yield P ==> G - * - * If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter. - * * for (P <- G) yield E ==> G.map (P => E) * - * Otherwise - * * 4. * * for (P_1 <- G_1; P_2 <- G_2; ...) ... @@ -2157,7 +2151,7 @@ object desugar { val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) if betterForsEnabled && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - // && deepEquals(gen.pat, body) + && deepEquals(gen.pat, body) then aply.putAttachment(TrailingForMap, ()) aply diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index 727f2b472cdd..87a38a7135bf 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -14,20 +14,22 @@ import dotty.tools.dotc.transform.MegaPhase.MiniPhase import dotty.tools.dotc.typer.ConstFold import dotty.tools.dotc.ast.desugar import scala.util.chaining.* +import tpd.* class DropForMap extends MiniPhase: - import tpd.* + import DropForMap.* + import Binder.* override def phaseName: String = DropForMap.name override def description: String = DropForMap.description override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = - if !tree.hasAttachment(desugar.TrailingForMap) then tree.tap(_.removeAttachment(desugar.TrailingForMap)) + if !tree.hasAttachment(desugar.TrailingForMap) then tree else tree match - case Apply(MapCall(f), List(Lambda(List(param), body))) - if isEssentiallyUnitLiteral(param, body) && param.tpt.tpe.isRef(defn.UnitClass) => - f + case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) + if canDropMap(Single(param), body) && f.tpe =:= aply.tpe => // make sure that the type of the expression won't change + f // drop the map call case _ => tree.tap(_.removeAttachment(desugar.TrailingForMap)) @@ -45,12 +47,57 @@ class DropForMap extends MiniPhase: case TypeApply(fn, _) => unapply(fn) case _ => None - def isEssentiallyUnitLiteral(param: ValDef, tree: Tree)(using Context): Boolean = tree match - case Literal(Constant(())) => true - case Match(scrutinee, List(CaseDef(_, EmptyTree, body))) => isEssentiallyUnitLiteral(param, body) - case Block(Nil, expr) => isEssentiallyUnitLiteral(param, expr) - case _ => false + /** We can drop the map call if: + * - it is a Unit literal + * - is an identity function (i.e. the last pattern is the same as the result) + */ + private def canDropMap(params: Binder, tree: Tree)(using Context): Boolean = tree match + case Literal(Constant(())) => params match + case Single(bind) => bind.symbol.info.isRef(defn.UnitClass) + case _ => false + case ident: Ident => params match + case Single(bind) => bind.symbol == ident.symbol + case _ => false + case tree: Apply if tree.tpe.typeSymbol.derivesFrom(defn.TupleClass) => params match + case Tuple(binds) => tree.args.zip(binds).forall((arg, param) => canDropMap(param, arg)) + case _ => false + case Match(scrutinee, List(CaseDef(pat, EmptyTree, body))) => + val newParams = newParamsFromMatch(params, scrutinee, pat) + canDropMap(newParams, body) + case Block(Nil, expr) => canDropMap(params, expr) + case _ => + false + + /** Extract potentially new parameters from a match expression + */ + private def newParamsFromMatch(params: Binder, scrutinee: Tree, pat: Tree)(using Context): Binder = + def extractTraverse(pats: List[Tree]): Option[List[Binder]] = pats match + case Nil => Some(List.empty) + case pat :: pats => + extractBinders(pat).map(_ +: extractTraverse(pats).get) + def extractBinders(pat: Tree): Option[Binder] = pat match + case bind: Bind => Some(Single(bind)) + case tree @ UnApply(fun, implicits, pats) + if implicits.isEmpty && tree.tpe.finalResultType.dealias.typeSymbol.derivesFrom(defn.TupleClass) => + extractTraverse(pats).map(Tuple.apply) + case _ => None + + params match + case Single(bind) if scrutinee.symbol == bind.symbol => + pat match + case bind: Bind => Single(bind) + case tree @ UnApply(fun, implicits, pats) if implicits.isEmpty => + val unapplied = tree.tpe.finalResultType.dealias.typeSymbol + if unapplied.derivesFrom(defn.TupleClass) then + extractTraverse(pats).map(Tuple.apply).getOrElse(params) + else params + case _ => params + case _ => params object DropForMap: val name: String = "dropForMap" val description: String = "Drop unused trailing map calls in for comprehensions" + + private enum Binder: + case Single(bind: NamedDefTree) + case Tuple(binds: List[Binder]) diff --git a/tests/pos/better-fors-i21804.scala b/tests/pos/better-fors-i21804.scala new file mode 100644 index 000000000000..7c8c753bf7c3 --- /dev/null +++ b/tests/pos/better-fors-i21804.scala @@ -0,0 +1,13 @@ +import scala.language.experimental.betterFors + +case class Container[A](val value: A) { + def map[B](f: A => B): Container[B] = Container(f(value)) +} + +sealed trait Animal +case class Dog() extends Animal + +def opOnDog(dog: Container[Dog]): Container[Animal] = + for + v <- dog + yield v diff --git a/tests/run/better-fors-map-elim.check b/tests/run/better-fors-map-elim.check new file mode 100644 index 000000000000..0ef3447a47c4 --- /dev/null +++ b/tests/run/better-fors-map-elim.check @@ -0,0 +1,4 @@ +MySome(()) +MySome(2) +MySome((2,3)) +MySome((2,(3,4))) diff --git a/tests/run/map-unit-elim.scala b/tests/run/better-fors-map-elim.scala similarity index 52% rename from tests/run/map-unit-elim.scala rename to tests/run/better-fors-map-elim.scala index cdefc7a8e839..c68acbe44789 100644 --- a/tests/run/map-unit-elim.scala +++ b/tests/run/better-fors-map-elim.scala @@ -1,6 +1,6 @@ import scala.language.experimental.betterFors -class myOptionPackage(doOnMap: => Unit) { +class myOptionModule(doOnMap: => Unit) { sealed trait MyOption[+A] { def map[B](f: A => B): MyOption[B] = this match { case MySome(x) => { @@ -16,19 +16,43 @@ class myOptionPackage(doOnMap: => Unit) { } case class MySome[A](x: A) extends MyOption[A] case object MyNone extends MyOption[Nothing] + object MyOption { + def apply[A](x: A): MyOption[A] = MySome(x) + } } object Test extends App { - val myOption = new myOptionPackage(println("map called")) + val myOption = new myOptionModule(println("map called")) import myOption.* val z = for { - a <- MySome(1) - b <- MySome(()) + a <- MyOption(1) + b <- MyOption(()) } yield () println(z) + val z2 = for { + a <- MyOption(1) + b <- MyOption(2) + } yield b + + println(z2) + + val z3 = for { + a <- MyOption(1) + (b, c) <- MyOption((2, 3)) + } yield (b, c) + + println(z3) + + val z4 = for { + a <- MyOption(1) + (b, (c, d)) <- MyOption((2, (3, 4))) + } yield (b, (c, d)) + + println(z4) + } diff --git a/tests/run/map-unit-elim.check b/tests/run/map-unit-elim.check deleted file mode 100644 index 1b7b66a37884..000000000000 --- a/tests/run/map-unit-elim.check +++ /dev/null @@ -1 +0,0 @@ -MySome(()) From 1807f316122745bb4908854e0f638447459676ff Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 17 Feb 2025 17:11:51 +0100 Subject: [PATCH 3/6] Simplify betterFors fix a lot --- .../src/dotty/tools/dotc/ast/Desugar.scala | 4 +- .../dotc/transform/localopt/DropForMap.scala | 70 ++----------------- 2 files changed, 9 insertions(+), 65 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index d591341dcae6..f41736f14afd 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -2151,7 +2151,7 @@ object desugar { val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) if betterForsEnabled && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) + && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) then aply.putAttachment(TrailingForMap, ()) aply @@ -2166,7 +2166,7 @@ object desugar { if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName else mapName val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) - if selectName == mapName then + if selectName == mapName && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) then aply.pushAttachment(TrailingForMap, ()) aply case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index 87a38a7135bf..1ae58def69e8 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -1,42 +1,37 @@ package dotty.tools.dotc package transform.localopt -import scala.language.unsafeNulls - -import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.core.Decorators.* -import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.* import dotty.tools.dotc.core.StdNames.* import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.transform.MegaPhase.MiniPhase -import dotty.tools.dotc.typer.ConstFold import dotty.tools.dotc.ast.desugar -import scala.util.chaining.* -import tpd.* class DropForMap extends MiniPhase: import DropForMap.* - import Binder.* override def phaseName: String = DropForMap.name override def description: String = DropForMap.description - override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = + override def transformApply(tree: Apply)(using Context): Tree = if !tree.hasAttachment(desugar.TrailingForMap) then tree else tree match case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) - if canDropMap(Single(param), body) && f.tpe =:= aply.tpe => // make sure that the type of the expression won't change + if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change f // drop the map call case _ => - tree.tap(_.removeAttachment(desugar.TrailingForMap)) + tree.removeAttachment(desugar.TrailingForMap) + tree private object Lambda: def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match - case Block(List(defdef: DefDef), Closure(Nil, ref, _)) if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => + case Block(List(defdef: DefDef), Closure(Nil, ref, _)) + if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => Some((defdef.termParamss.flatten, defdef.rhs)) case _ => None @@ -47,57 +42,6 @@ class DropForMap extends MiniPhase: case TypeApply(fn, _) => unapply(fn) case _ => None - /** We can drop the map call if: - * - it is a Unit literal - * - is an identity function (i.e. the last pattern is the same as the result) - */ - private def canDropMap(params: Binder, tree: Tree)(using Context): Boolean = tree match - case Literal(Constant(())) => params match - case Single(bind) => bind.symbol.info.isRef(defn.UnitClass) - case _ => false - case ident: Ident => params match - case Single(bind) => bind.symbol == ident.symbol - case _ => false - case tree: Apply if tree.tpe.typeSymbol.derivesFrom(defn.TupleClass) => params match - case Tuple(binds) => tree.args.zip(binds).forall((arg, param) => canDropMap(param, arg)) - case _ => false - case Match(scrutinee, List(CaseDef(pat, EmptyTree, body))) => - val newParams = newParamsFromMatch(params, scrutinee, pat) - canDropMap(newParams, body) - case Block(Nil, expr) => canDropMap(params, expr) - case _ => - false - - /** Extract potentially new parameters from a match expression - */ - private def newParamsFromMatch(params: Binder, scrutinee: Tree, pat: Tree)(using Context): Binder = - def extractTraverse(pats: List[Tree]): Option[List[Binder]] = pats match - case Nil => Some(List.empty) - case pat :: pats => - extractBinders(pat).map(_ +: extractTraverse(pats).get) - def extractBinders(pat: Tree): Option[Binder] = pat match - case bind: Bind => Some(Single(bind)) - case tree @ UnApply(fun, implicits, pats) - if implicits.isEmpty && tree.tpe.finalResultType.dealias.typeSymbol.derivesFrom(defn.TupleClass) => - extractTraverse(pats).map(Tuple.apply) - case _ => None - - params match - case Single(bind) if scrutinee.symbol == bind.symbol => - pat match - case bind: Bind => Single(bind) - case tree @ UnApply(fun, implicits, pats) if implicits.isEmpty => - val unapplied = tree.tpe.finalResultType.dealias.typeSymbol - if unapplied.derivesFrom(defn.TupleClass) then - extractTraverse(pats).map(Tuple.apply).getOrElse(params) - else params - case _ => params - case _ => params - object DropForMap: val name: String = "dropForMap" val description: String = "Drop unused trailing map calls in for comprehensions" - - private enum Binder: - case Single(bind: NamedDefTree) - case Tuple(binds: List[Binder]) From 867fb354bd36ea92d112baeefc7cd9941d507b29 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 17 Feb 2025 17:22:10 +0100 Subject: [PATCH 4/6] docs and comments --- .../dotty/tools/dotc/transform/localopt/DropForMap.scala | 7 +++++++ docs/_docs/reference/experimental/better-fors.md | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala index 1ae58def69e8..f7594f041204 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -10,6 +10,13 @@ import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.transform.MegaPhase.MiniPhase import dotty.tools.dotc.ast.desugar +/** Drop unused trailing map calls in for comprehensions. + * We can drop the map call if: + * - it won't change the type of the expression, and + * - the function is an identity function or a const function to unit. + * + * The latter condition is checked in [[Desugar.scala#makeFor]] + */ class DropForMap extends MiniPhase: import DropForMap.* diff --git a/docs/_docs/reference/experimental/better-fors.md b/docs/_docs/reference/experimental/better-fors.md index a4c42c9fb380..4f910259aab2 100644 --- a/docs/_docs/reference/experimental/better-fors.md +++ b/docs/_docs/reference/experimental/better-fors.md @@ -60,7 +60,7 @@ Additionally this extension changes the way `for`-comprehensions are desugared. This change makes the desugaring more intuitive and avoids unnecessary `map` calls, when an alias is not followed by a guard. 2. **Avoiding Redundant `map` Calls**: - When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. but th eequality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. + When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. But the equality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. There is also a special case for dropping the `map`, if its body is a constant function, that returns `()` (`Unit` constant). **Current Desugaring**: ```scala for { From 213a3fd4ece7206bb53545351e72bdaf44c53594 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 18 Feb 2025 11:43:42 +0100 Subject: [PATCH 5/6] Use custom print as a workaround for ScalaJS tests --- tests/run/better-fors-map-elim.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/run/better-fors-map-elim.scala b/tests/run/better-fors-map-elim.scala index c68acbe44789..653984bc8e28 100644 --- a/tests/run/better-fors-map-elim.scala +++ b/tests/run/better-fors-map-elim.scala @@ -27,32 +27,38 @@ object Test extends App { import myOption.* + def portablePrintMyOption(opt: MyOption[Any]): Unit = + if opt == MySome(()) then + println("MySome(())") + else + println(opt) + val z = for { a <- MyOption(1) b <- MyOption(()) } yield () - println(z) + portablePrintMyOption(z) val z2 = for { a <- MyOption(1) b <- MyOption(2) } yield b - println(z2) + portablePrintMyOption(z2) val z3 = for { a <- MyOption(1) (b, c) <- MyOption((2, 3)) } yield (b, c) - println(z3) + portablePrintMyOption(z3) val z4 = for { a <- MyOption(1) (b, (c, d)) <- MyOption((2, (3, 4))) } yield (b, (c, d)) - println(z4) + portablePrintMyOption(z4) } From 86645e685d326065e2b385e071ed99d2dcf1e115 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Fri, 21 Feb 2025 10:54:29 +0100 Subject: [PATCH 6/6] Review fixes for betterFors fix --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index f41736f14afd..2d0d6d25b190 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -2145,15 +2145,19 @@ object desugar { case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) case _ => false + def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit = + if betterForsEnabled + && selectName == mapName + && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) + then + aply.putAttachment(TrailingForMap, ()) + enums match { case Nil if betterForsEnabled => body case (gen: GenFrom) :: Nil => val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) - if betterForsEnabled - && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) - then - aply.putAttachment(TrailingForMap, ()) + markTrailingMap(aply, gen, mapName) aply case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) @@ -2166,8 +2170,7 @@ object desugar { if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName else mapName val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) - if selectName == mapName && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) then - aply.pushAttachment(TrailingForMap, ()) + markTrailingMap(aply, gen, selectName) aply case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])