Skip to content

Commit 47eb542

Browse files
committed
Refactor refined function logic
1 parent 6e45dd7 commit 47eb542

27 files changed

+178
-174
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
954954
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
955955
def isStructuralTermSelect(tree: Select) =
956956
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
957-
case defn.PolyFunctionOf(_) =>
957+
case defn.FunctionOf(_) =>
958958
false
959959
case RefinedType(parent, rname, rinfo) =>
960960
rname == tree.name || hasRefinement(parent)

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,13 +1152,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11521152

11531153
def etaExpandCFT(using Context): Tree =
11541154
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1155-
case defn.ContextFunctionType(argTypes, resType, _) =>
1156-
val anonFun = newAnonFun(
1157-
ctx.owner,
1158-
MethodType.companion(isContextual = true)(argTypes, resType),
1159-
coord = ctx.owner.coord)
1155+
case defn.FunctionOf(mt: MethodType) if mt.isContextualMethod && !mt.isResultDependent => // TODO handle result-dependent functions?
1156+
val anonFun = newAnonFun(ctx.owner, mt, coord = ctx.owner.coord)
11601157
def lambdaBody(refss: List[List[Tree]]) =
1161-
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
1158+
expand(target.select(nme.apply).appliedToArgss(refss), mt.resType)(
11621159
using ctx.withOwner(anonFun))
11631160
Closure(anonFun, lambdaBody)
11641161
case _ =>

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ object CaptureSet:
881881
++ (recur(rinfo.resType) // add capture set of result
882882
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
883883
case tpd @ AppliedType(tycon, args) =>
884-
if followResult && defn.isNonRefinedFunction(tpd) then
884+
if followResult && defn.isFunctionNType(tpd) then
885885
recur(args.last)
886886
// must be (pure) FunctionN type since ImpureFunctions have already
887887
// been eliminated in selector's dealias. Use capture set of result.

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
195195
capt.println(i"solving $t")
196196
refs.solve()
197197
traverse(parent)
198-
case t @ defn.RefinedFunctionOf(rinfo) =>
198+
case defn.RefinedFunctionOf(rinfo) =>
199199
traverse(rinfo)
200200
case tp: TypeVar =>
201201
case tp: TypeRef =>
@@ -408,10 +408,10 @@ class CheckCaptures extends Recheck, SymTransformer:
408408
else if meth == defn.Caps_unsafeUnbox then
409409
mapArgUsing(_.forceBoxStatus(false))
410410
else if meth == defn.Caps_unsafeBoxFunArg then
411-
mapArgUsing:
412-
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
413-
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)
414-
411+
mapArgUsing: tp =>
412+
val defn.FunctionOf(mt: MethodType) = tp.dealias: @unchecked
413+
mt.derivedLambdaType(resType = mt.resType.forceBoxStatus(true))
414+
.toFunctionType()
415415
else
416416
super.recheckApply(tree, pt) match
417417
case appType @ CapturingType(appType1, refs) =>
@@ -502,8 +502,9 @@ class CheckCaptures extends Recheck, SymTransformer:
502502
block match
503503
case closureDef(mdef) =>
504504
pt.dealias match
505-
case defn.FunctionOf(ptformals, _, _)
506-
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
505+
case defn.FunctionOf(mt0: MethodType)
506+
if mt0.paramInfos.nonEmpty && mt0.paramInfos.forall(_.captureSet.isAlwaysEmpty) =>
507+
val ptformals = mt0.paramInfos
507508
// Redo setup of the anonymous function so that formal parameters don't
508509
// get capture sets. This is important to avoid false widenings to `cap`
509510
// when taking the base type of the actual closures's dependent function
@@ -707,10 +708,12 @@ class CheckCaptures extends Recheck, SymTransformer:
707708
val eparent1 = recur(eparent)
708709
if eparent1 eq eparent then expected
709710
else CapturingType(eparent1, refs, boxed = expected0.isBoxed)
710-
case expected @ defn.FunctionOf(args, resultType, isContextual)
711-
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
712-
val expected1 = toDepFun(args, resultType, isContextual)
713-
expected1
711+
case defn.FunctionOf(mt: MethodType) =>
712+
actual.dealias match
713+
case defn.FunctionOf(mt2: MethodType) if mt2.isResultDependent =>
714+
mt.toFunctionType(alwaysDependent = true)
715+
case _ =>
716+
expected
714717
case _ =>
715718
expected
716719
recur(expected)
@@ -781,9 +784,8 @@ class CheckCaptures extends Recheck, SymTransformer:
781784

782785
try
783786
val (eargs, eres) = expected.dealias.stripCapturing match
784-
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
785787
case expected: MethodType => (expected.paramInfos, expected.resType)
786-
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
788+
case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType)
787789
case _ => (aargs.map(_ => WildcardType), WildcardType)
788790
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
789791
val ares1 = adapt(ares, eres, covariant)
@@ -842,7 +844,7 @@ class CheckCaptures extends Recheck, SymTransformer:
842844

843845
// Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation
844846
val (styp1, leaked) = styp match {
845-
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
847+
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
846848
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
847849
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
848850
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ extends tpd.TreeTraverser:
4949
def recur(tp: Type): Type = tp.dealias match
5050
case tp @ CapturingType(parent, refs) if !tp.isBoxed =>
5151
tp.boxed
52-
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
52+
case tp1 @ AppliedType(tycon, args) if defn.isFunctionNType(tp1) =>
5353
val res = args.last
5454
val boxedRes = recur(res)
5555
if boxedRes eq res then tp
@@ -129,7 +129,7 @@ extends tpd.TreeTraverser:
129129
apply(parent)
130130
case tp @ AppliedType(tycon, args) =>
131131
val tycon1 = this(tycon)
132-
if defn.isNonRefinedFunction(tp) then
132+
if defn.isFunctionNType(tp) then
133133
// Convert toplevel generic function types to dependent functions
134134
if !defn.isFunctionSymbol(tp.typeSymbol) && (tp.dealias ne tp) then
135135
// This type is a function after dealiasing, so we dealias and recurse.
@@ -197,7 +197,7 @@ extends tpd.TreeTraverser:
197197
val mt = ContextualMethodType(paramName :: Nil)(
198198
_ => paramType :: Nil,
199199
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
200-
val fntpe = defn.PolyFunctionOf(mt)
200+
val fntpe = mt.toFunctionType()
201201
if !encl.isEmpty && isLast then
202202
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
203203
CapturingType(fntpe, cs, boxed = false)

compiler/src/dotty/tools/dotc/cc/Synthetics.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ object Synthetics:
174174
val (et: ExprType) = symd.info: @unchecked
175175
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
176176
def mapFinalResult(tp: Type, f: Type => Type): Type =
177-
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
177+
val defn.FunctionNOf(args, res, isContextual) = tp: @unchecked
178178
if defn.isFunctionNType(res) then
179-
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
179+
defn.FunctionNOf(args, mapFinalResult(res, f), isContextual)
180180
else
181181
f(tp)
182182
val resType1 =

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,23 +1109,34 @@ class Definitions {
11091109
sym.owner.linkedClass.typeRef
11101110

11111111
object FunctionOf {
1112+
/** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
1113+
* Extracts the method type type and apply info.
1114+
*/
1115+
def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
1116+
ft match
1117+
case RefinedFunctionOf(mt) => Some(mt)
1118+
case FunctionNOf(argTypes, resultType, isContextual) =>
1119+
val methodType = if isContextual then ContextualMethodType else MethodType
1120+
Some(methodType(argTypes, resultType))
1121+
case _ => None
1122+
}
1123+
}
1124+
1125+
object FunctionNOf {
1126+
/** Create a `FunctionN` or `ContextFunctionN` type applied to the arguments and result type */
11121127
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
1113-
val mt = MethodType.companion(isContextual, false)(args, resultType)
1114-
if mt.hasErasedParams then
1115-
RefinedType(PolyFunctionClass.typeRef, nme.apply, mt)
1116-
else
1117-
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
1118-
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
1119-
ft.dealias match
1120-
case PolyFunctionOf(mt: MethodType) =>
1121-
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
1122-
case dft =>
1123-
val tsym = dft.typeSymbol
1124-
if isFunctionSymbol(tsym) && ft.isRef(tsym) then
1125-
val targs = dft.argInfos
1126-
if (targs.isEmpty) None
1127-
else Some(targs.init, targs.last, tsym.name.isContextFunction)
1128-
else None
1128+
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
1129+
1130+
/** Matches a (possibly aliased) `FunctionN[...]` or `ContextFunctionN[...]`.
1131+
* Extracts the list of function argument types, the result type and whether function is contextual.
1132+
*/
1133+
def unapply(tpe: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
1134+
val tsym = tpe.typeSymbol
1135+
if isFunctionSymbol(tsym) && tpe.isRef(tsym) then
1136+
val targs = tpe.argInfos
1137+
if (targs.isEmpty) None
1138+
else Some(targs.init, targs.last, tsym.name.isContextFunction)
1139+
else None
11291140
}
11301141
}
11311142

@@ -1165,6 +1176,7 @@ class Definitions {
11651176
def isValidMethodType(info: Type) = info match
11661177
case info: MethodType =>
11671178
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
1179+
&& !info.isParamDependent
11681180
case _ => false
11691181
info match
11701182
case info: PolyType => isValidMethodType(info.resType)
@@ -1731,26 +1743,20 @@ class Definitions {
17311743

17321744
def isProductSubType(tp: Type)(using Context): Boolean = tp.derivesFrom(ProductClass)
17331745

1734-
/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ContextFunctionN
1735-
* instance?
1746+
/** Returns whether `tp` is an instance or a refined instance of:
1747+
* - scala.FunctionN
1748+
* - scala.ContextFunctionN
17361749
*/
1737-
def isNonRefinedFunction(tp: Type)(using Context): Boolean =
1738-
val arity = functionArity(tp)
1739-
val sym = tp.dealias.typeSymbol
1750+
def isFunctionNType(tp: Type)(using Context): Boolean =
1751+
val tp1 = tp.dropDependentRefinement
1752+
val arity = functionArity(tp1)
1753+
val sym = tp1.dealias.typeSymbol
17401754

17411755
arity >= 0
17421756
&& isFunctionClass(sym)
1743-
&& tp.isRef(
1757+
&& tp1.isRef(
17441758
FunctionType(arity, sym.name.isContextFunction).typeSymbol,
17451759
skipRefined = false)
1746-
end isNonRefinedFunction
1747-
1748-
/** Returns whether `tp` is an instance or a refined instance of:
1749-
* - scala.FunctionN
1750-
* - scala.ContextFunctionN
1751-
*/
1752-
def isFunctionNType(tp: Type)(using Context): Boolean =
1753-
isNonRefinedFunction(tp.dropDependentRefinement)
17541760

17551761
/** Returns whether `tp` is an instance or a refined instance of:
17561762
* - scala.FunctionN
@@ -1873,24 +1879,6 @@ class Definitions {
18731879
def isContextFunctionType(tp: Type)(using Context): Boolean =
18741880
asContextFunctionType(tp).exists
18751881

1876-
/** An extractor for context function types `As ?=> B`, possibly with
1877-
* dependent refinements. Optionally returns a triple consisting of the argument
1878-
* types `As`, the result type `B` and a whether the type is an erased context function.
1879-
*/
1880-
object ContextFunctionType:
1881-
def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] =
1882-
if ctx.erasedTypes then
1883-
atPhase(erasurePhase)(unapply(tp))
1884-
else
1885-
asContextFunctionType(tp) match
1886-
case PolyFunctionOf(mt: MethodType) =>
1887-
Some((mt.paramInfos, mt.resType, mt.erasedParams))
1888-
case tp1 if tp1.exists =>
1889-
val args = tp1.functionArgInfos
1890-
val erasedParams = List.fill(functionArity(tp1)) { false }
1891-
Some((args.init, args.last, erasedParams))
1892-
case _ => None
1893-
18941882
/** A whitelist of Scala-2 classes that are known to be pure */
18951883
def isAssuredNoInits(sym: Symbol): Boolean =
18961884
(sym `eq` SomeClass) || isTupleClass(sym)

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ object TypeErasure {
567567
functionType(info.resultType)
568568
case info: MethodType =>
569569
assert(!info.resultType.isInstanceOf[MethodicType])
570-
defn.FunctionType(n = info.erasedParams.count(_ == false))
570+
defn.FunctionType(n = info.nonErasedParamCount)
571571
}
572572
erasure(functionType(applyInfo))
573573
}
@@ -933,7 +933,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
933933
case tp: TermRef =>
934934
sigName(underlyingOfTermRef(tp))
935935
case ExprType(rt) =>
936-
sigName(defn.FunctionOf(Nil, rt))
936+
sigName(defn.FunctionNOf(Nil, rt))
937937
case tp: TypeVar if !tp.isInstantiated =>
938938
tpnme.Uninstantiated
939939
case tp @ defn.PolyFunctionOf(_) =>

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,7 +1519,7 @@ object Types {
15191519

15201520
/** Dealias, and if result is a dependent function type, drop the `apply` refinement. */
15211521
final def dropDependentRefinement(using Context): Type = dealias match {
1522-
case RefinedType(parent, nme.apply, mt) if defn.isNonRefinedFunction(parent) => parent
1522+
case RefinedType(parent, nme.apply, mt) if defn.isFunctionNType(parent) => parent
15231523
case tp => tp
15241524
}
15251525

@@ -1890,19 +1890,30 @@ object Types {
18901890
case res: MethodType => res.toFunctionType(isJava)
18911891
case res => res
18921892
}
1893-
defn.FunctionOf(
1893+
defn.FunctionNOf(
18941894
mt.paramInfos.mapConserve(_.translateFromRepeated(toArray = isJava)),
18951895
result1, isContextual)
18961896
if mt.hasErasedParams then
1897-
defn.PolyFunctionOf(mt)
1897+
assert(isValidPolyFunctionInfo(mt), s"Not a valid PolyFunction refinement: $mt")
1898+
RefinedType(defn.PolyFunctionType, nme.apply, mt)
18981899
else if alwaysDependent || mt.isResultDependent then
18991900
RefinedType(nonDependentFunType, nme.apply, mt)
19001901
else nonDependentFunType
1901-
case poly @ PolyType(_, mt: MethodType) =>
1902-
assert(!mt.isParamDependent)
1903-
defn.PolyFunctionOf(poly)
1902+
case poly: PolyType =>
1903+
assert(isValidPolyFunctionInfo(poly), s"Not a valid PolyFunction refinement: $poly")
1904+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
19041905
}
19051906

1907+
private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
1908+
def isValidMethodType(info: Type) = info match
1909+
case info: MethodType =>
1910+
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
1911+
&& !info.isParamDependent
1912+
case _ => false
1913+
info match
1914+
case info: PolyType => isValidMethodType(info.resType)
1915+
case _ => isValidMethodType(info)
1916+
19061917
/** The signature of this type. This is by default NotAMethod,
19071918
* but is overridden for PolyTypes, MethodTypes, and TermRef types.
19081919
* (the reason why we deviate from the "final-method-with-pattern-match-in-base-class"
@@ -3724,8 +3735,6 @@ object Types {
37243735

37253736
def companion: LambdaTypeCompanion[ThisName, PInfo, This]
37263737

3727-
def erasedParams(using Context) = List.fill(paramInfos.size)(false)
3728-
37293738
/** The type `[tparams := paramRefs] tp`, where `tparams` can be
37303739
* either a list of type parameter symbols or a list of lambda parameters
37313740
*
@@ -4017,13 +4026,18 @@ object Types {
40174026
final override def isImplicitMethod: Boolean =
40184027
companion.eq(ImplicitMethodType) || isContextualMethod
40194028
final override def hasErasedParams(using Context): Boolean =
4020-
erasedParams.contains(true)
4029+
paramInfos.exists(p => p.hasAnnotation(defn.ErasedParamAnnot))
4030+
40214031
final override def isContextualMethod: Boolean =
40224032
companion.eq(ContextualMethodType)
40234033

4024-
override def erasedParams(using Context): List[Boolean] =
4034+
def erasedParams(using Context): List[Boolean] =
40254035
paramInfos.map(p => p.hasAnnotation(defn.ErasedParamAnnot))
40264036

4037+
def nonErasedParamCount(using Context): Int =
4038+
paramInfos.count(p => !p.hasAnnotation(defn.ErasedParamAnnot))
4039+
4040+
40274041
protected def prefixString: String = companion.prefixString
40284042
}
40294043

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
297297
"(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")"
298298

299299
protected def paramsText(lam: LambdaType): Text = {
300-
val erasedParams = lam.erasedParams
301-
def paramText(ref: ParamRef, erased: Boolean) =
300+
def paramText(ref: ParamRef) =
301+
val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot)
302302
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true)
303-
Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ")
303+
Text(lam.paramRefs.map(paramText), ", ")
304304
}
305305

306306
protected def ParamRefNameString(name: Name): String = nameString(name)

0 commit comments

Comments
 (0)