Skip to content

Commit 3a2190a

Browse files
authored
Merge pull request #12863 from dotty-staging/generalize-polyfuns
Drop implementation restriction for polymorphic functions
2 parents c03b577 + 797e261 commit 3a2190a

File tree

5 files changed

+35
-43
lines changed

5 files changed

+35
-43
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,12 +1632,12 @@ object desugar {
16321632
}
16331633
}
16341634

1635-
def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match {
1635+
def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match
16361636
case Parens(body1) =>
16371637
makePolyFunction(targs, body1)
16381638
case Block(Nil, body1) =>
16391639
makePolyFunction(targs, body1)
1640-
case Function(vargs, res) =>
1640+
case _ =>
16411641
assert(targs.nonEmpty)
16421642
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
16431643
val mods = body match {
@@ -1646,33 +1646,37 @@ object desugar {
16461646
}
16471647
val polyFunctionTpt = ref(defn.PolyFunctionType)
16481648
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1649-
if (ctx.mode.is(Mode.Type)) {
1649+
if ctx.mode.is(Mode.Type) then
16501650
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
16511651
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1652-
1653-
val applyVParams = vargs.zipWithIndex.map {
1654-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1655-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
1656-
}
1652+
val (res, applyVParamss) = body match
1653+
case Function(vargs, res) =>
1654+
( res,
1655+
vargs.zipWithIndex.map {
1656+
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1657+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
1658+
} :: Nil
1659+
)
1660+
case _ =>
1661+
(body, Nil)
16571662
RefinedTypeTree(polyFunctionTpt, List(
1658-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree)
1663+
DefDef(nme.apply, applyTParams :: applyVParamss, res, EmptyTree)
16591664
))
1660-
}
1661-
else {
1665+
else
16621666
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
16631667
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1664-
1665-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1666-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1667-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1668-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res))
1669-
))
1670-
}
1671-
case _ =>
1672-
// may happen for erroneous input. An error will already have been reported.
1673-
assert(ctx.reporter.errorsReported)
1674-
EmptyTree
1675-
}
1668+
val (res, applyVParamss) = body match
1669+
case Function(vargs, res) =>
1670+
( res,
1671+
vargs.asInstanceOf[List[ValDef]]
1672+
.map(varg => varg.withAddedFlags(mods.flags | Param))
1673+
:: Nil
1674+
)
1675+
case _ =>
1676+
(body, Nil)
1677+
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1678+
List(DefDef(nme.apply, applyTParams :: applyVParamss, TypeTree(), res))
1679+
))
16761680

16771681
// begin desugar
16781682

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
549549
* - otherwise, if T is a type parameter coming from Java, []Object
550550
* - otherwise, Object
551551
* - For a term ref p.x, the type <noprefix> # x.
552+
* - For a refined type scala.PolyFunction { def apply[...]: R }, scala.Function0
552553
* - For a refined type scala.PolyFunction { def apply[...](x_1, ..., x_N): R }, scala.FunctionN
553554
* - For a typeref scala.Any, scala.AnyVal, scala.Singleton, scala.Tuple, or scala.*: : |java.lang.Object|
554555
* - For a typeref scala.Unit, |scala.runtime.BoxedUnit|.
@@ -600,8 +601,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
600601
assert(refinedInfo.isInstanceOf[PolyType])
601602
val res = refinedInfo.resultType
602603
val paramss = res.paramNamess
603-
assert(paramss.length == 1)
604-
this(defn.FunctionType(paramss.head.length, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod))
604+
assert(paramss.length <= 1)
605+
val arity = if paramss.isEmpty then 0 else paramss.head.length
606+
this(defn.FunctionType(arity, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod))
605607
case tp: TypeProxy =>
606608
this(tp.underlying)
607609
case tp @ AndType(tp1, tp2) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,14 +1434,7 @@ object Parsers {
14341434
else if (in.token == ARROW) {
14351435
val arrowOffset = in.skipToken()
14361436
val body = toplevelTyp()
1437-
atSpan(start, arrowOffset) {
1438-
if (isFunction(body))
1439-
PolyFunction(tparams, body)
1440-
else {
1441-
syntaxError("Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
1442-
Ident(nme.ERROR.toTypeName)
1443-
}
1444-
}
1437+
atSpan(start, arrowOffset) { PolyFunction(tparams, body) }
14451438
}
14461439
else { accept(TLARROW); typ() }
14471440
}
@@ -1917,14 +1910,7 @@ object Parsers {
19171910
val tparams = typeParamClause(ParamOwner.TypeParam)
19181911
val arrowOffset = accept(ARROW)
19191912
val body = expr(location)
1920-
atSpan(start, arrowOffset) {
1921-
if (isFunction(body))
1922-
PolyFunction(tparams, body)
1923-
else {
1924-
syntaxError("Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)
1925-
errorTermTree
1926-
}
1927-
}
1913+
atSpan(start, arrowOffset) { PolyFunction(tparams, body) }
19281914
case _ =>
19291915
val saved = placeholderParams
19301916
placeholderParams = Nil

tests/neg/i2887b.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
trait A { type S[X[_] <: [_] => Any, Y[_]] <: [_] => Any; type I[_] } // error // error
2-
trait B { type S[X[_],Y[_]]; type I[_] <: [_] => Any } // error
1+
trait A { type S[X[_] <: [_] => Any, Y[_]] <: [_] => Any; type I[_] }
2+
trait B { type S[X[_],Y[_]]; type I[_] <: [_] => Any }
33
trait C { type M <: B }
44
trait D { type M >: A }
55

0 commit comments

Comments
 (0)