Skip to content

Properly handle by-name function types in REPL #18761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ElimByName.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ import dotty.tools.dotc.core.Names.Name
*
* Note also that the transformation applies only to types of parameters, not to other
* occurrences of ExprTypes. In particular, embedded occurrences in function types
* such as `(=> T) => U` are left as-is here (they are eliminated in erasure).
* such as `(=> T) => U` are left as-is here (they are eliminated in erasure),
* unless they are in a `((=> T) => U)#apply` term reference.
* Trying to convert these as well would mean traversing all the types, and that
* leads to cyclic reference errors in many cases. This can cause problems in that
* we might have sometimes a `() ?=> T` where a `=> T` is expected. To compensate,
Expand All @@ -60,7 +61,7 @@ class ElimByName extends MiniPhase, InfoTransformer:
override def description: String = ElimByName.description

override def runsAfterGroupsOf: Set[String] = Set(ExpandSAMs.name, ElimRepeated.name, RefChecks.name)
// - ExpanSAMs applied to partial functions creates methods that need
// - ExpandSAMs applied to partial functions creates methods that need
// to be fully defined before converting. Test case is pos/i9391.scala.
// - ElimByName needs to run in a group after ElimRepeated since ElimRepeated
// works on simple arguments but not converted closures, and it sees the arguments
Expand Down Expand Up @@ -116,10 +117,10 @@ class ElimByName extends MiniPhase, InfoTransformer:
else tree

override def transformIdent(tree: Ident)(using Context): Tree =
applyIfFunction(tree)
transformFunctionTypeApplyReferences(applyIfFunction(tree))

override def transformSelect(tree: Select)(using Context): Tree =
applyIfFunction(tree)
transformFunctionTypeApplyReferences(applyIfFunction(tree))

override def transformTypeApply(tree: TypeApply)(using Context): Tree = tree match {
case TypeApply(Select(_, nme.asInstanceOf_), arg :: Nil) =>
Expand All @@ -129,6 +130,25 @@ class ElimByName extends MiniPhase, InfoTransformer:
case _ => tree
}

/** Transform references to by-name function type apply
*
* ((=> T1) => R)#apply ---> ((() ?=> T1) => R)#apply
*/
private def transformFunctionTypeApplyReferences(tree: Tree)(using Context) =
val tpe1 = new TypeMap {
def apply(tp: Type): Type = tp match
case tp @ AppliedType(tycon, args) if defn.isFunctionType(tp) =>
val args1 = args.mapConserve {
case ExprType(tp) => defn.ByNameFunction(tp)
case arg => arg
}
tp.derivedAppliedType(tycon, args1)
case tp: TermRef if tp.termSymbol.name == nme.apply && defn.isFunctionClass(tp.termSymbol.owner) =>
mapOver(tp)
case tp => tp
}.apply(tree.tpe)
tree.withType(tpe1)

override def transformApply(tree: Apply)(using Context): Tree =
trace(s"transforming ${tree.show} at phase ${ctx.phase}", show = true) {

Expand Down
5 changes: 5 additions & 0 deletions compiler/test-resources/repl/i18756
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
scala> def f: ( => Int) => Int = i => i ; f(1)
def f: (=> Int) => Int
val res0: Int = 1
scala> f(1)
val res1: Int = 1
5 changes: 5 additions & 0 deletions compiler/test-resources/repl/i18756b
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
scala> def f: ( => Int) => ( => Int) => Int = i => j => i + j; f(1)(2)
def f: (=> Int) => (=> Int) => Int
val res0: Int = 3
scala> f(1)(2)
val res1: Int = 3