Skip to content

Commit 18114f0

Browse files
committed
Refactor refined function logic
1 parent 6e370a9 commit 18114f0

28 files changed

+226
-210
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
954954
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
955955
def isStructuralTermSelect(tree: Select) =
956956
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
957-
case defn.PolyFunctionOf(_) =>
957+
case defn.FunctionOf(_) =>
958958
false
959959
case RefinedType(parent, rname, rinfo) =>
960960
rname == tree.name || hasRefinement(parent)

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,13 +1152,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11521152

11531153
def etaExpandCFT(using Context): Tree =
11541154
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1155-
case defn.ContextFunctionType(argTypes, resType, _) =>
1156-
val anonFun = newAnonFun(
1157-
ctx.owner,
1158-
MethodType.companion(isContextual = true)(argTypes, resType),
1159-
coord = ctx.owner.coord)
1155+
case defn.FunctionOf(mt: MethodType) if mt.isContextualMethod && !mt.isResultDependent => // TODO handle result-dependent functions?
1156+
val anonFun = newAnonFun(ctx.owner, mt, coord = ctx.owner.coord)
11601157
def lambdaBody(refss: List[List[Tree]]) =
1161-
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
1158+
expand(target.select(nme.apply).appliedToArgss(refss), mt.resType)(
11621159
using ctx.withOwner(anonFun))
11631160
Closure(anonFun, lambdaBody)
11641161
case _ =>

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -876,13 +876,13 @@ object CaptureSet:
876876
empty
877877
case CapturingType(parent, refs) =>
878878
recur(parent) ++ refs
879-
case tpd @ RefinedType(parent, _, rinfo: MethodType)
880-
if followResult && defn.isFunctionNType(tpd) =>
881-
ofType(parent, followResult = false) // pick up capture set from parent type
879+
case tpd @ defn.RefinedFunctionOf(rinfo: MethodType)
880+
if followResult =>
881+
ofType(tpd.parent, followResult = false) // pick up capture set from parent type
882882
++ (recur(rinfo.resType) // add capture set of result
883883
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
884884
case tpd @ AppliedType(tycon, args) =>
885-
if followResult && defn.isNonRefinedFunction(tpd) then
885+
if followResult && defn.isFunctionNType(tpd) then
886886
recur(args.last)
887887
// must be (pure) FunctionN type since ImpureFunctions have already
888888
// been eliminated in selector's dealias. Use capture set of result.

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
195195
capt.println(i"solving $t")
196196
refs.solve()
197197
traverse(parent)
198-
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
198+
case defn.RefinedFunctionOf(rinfo) =>
199199
traverse(rinfo)
200200
case tp: TypeVar =>
201201
case tp: TypeRef =>
@@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer:
302302
t
303303
case _ =>
304304
val t1 = t match
305-
case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) =>
306-
t.derivedRefinedType(parent, rname, this(rinfo))
305+
case t @ defn.RefinedFunctionOf(rinfo: MethodType) =>
306+
t.derivedRefinedType(t.parent, t.refinedName, this(rinfo))
307307
case _ =>
308308
mapOver(t)
309309
if variance > 0 then t1
@@ -408,10 +408,10 @@ class CheckCaptures extends Recheck, SymTransformer:
408408
else if meth == defn.Caps_unsafeUnbox then
409409
mapArgUsing(_.forceBoxStatus(false))
410410
else if meth == defn.Caps_unsafeBoxFunArg then
411-
mapArgUsing:
412-
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
413-
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)
414-
411+
mapArgUsing: tp =>
412+
val defn.FunctionOf(mt: MethodType) = tp.dealias: @unchecked
413+
mt.derivedLambdaType(resType = mt.resType.forceBoxStatus(true))
414+
.toFunctionType()
415415
else
416416
super.recheckApply(tree, pt) match
417417
case appType @ CapturingType(appType1, refs) =>
@@ -502,8 +502,9 @@ class CheckCaptures extends Recheck, SymTransformer:
502502
block match
503503
case closureDef(mdef) =>
504504
pt.dealias match
505-
case defn.FunctionOf(ptformals, _, _)
506-
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
505+
case defn.FunctionOf(mt0: MethodType)
506+
if mt0.paramInfos.nonEmpty && mt0.paramInfos.forall(_.captureSet.isAlwaysEmpty) =>
507+
val ptformals = mt0.paramInfos
507508
// Redo setup of the anonymous function so that formal parameters don't
508509
// get capture sets. This is important to avoid false widenings to `cap`
509510
// when taking the base type of the actual closures's dependent function
@@ -696,21 +697,19 @@ class CheckCaptures extends Recheck, SymTransformer:
696697
//println(i"check conforms $actual1 <<< $expected1")
697698
super.checkConformsExpr(actual1, expected1, tree)
698699

699-
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
700-
MethodType.companion(isContextual = isContextual)(args, resultType)
701-
.toFunctionType(isJava = false, alwaysDependent = true)
702-
703700
/** Turn `expected` into a dependent function when `actual` is dependent. */
704701
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
705702
def recur(expected: Type): Type = expected.dealias match
706703
case expected0 @ CapturingType(eparent, refs) =>
707704
val eparent1 = recur(eparent)
708705
if eparent1 eq eparent then expected
709706
else CapturingType(eparent1, refs, boxed = expected0.isBoxed)
710-
case expected @ defn.FunctionOf(args, resultType, isContextual)
711-
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
712-
val expected1 = toDepFun(args, resultType, isContextual)
713-
expected1
707+
case defn.FunctionOf(mt: MethodType) =>
708+
actual.dealias match
709+
case defn.FunctionOf(mt2: MethodType) if mt2.isResultDependent =>
710+
mt.toFunctionType(alwaysDependent = true)
711+
case _ =>
712+
expected
714713
case _ =>
715714
expected
716715
recur(expected)
@@ -781,9 +780,8 @@ class CheckCaptures extends Recheck, SymTransformer:
781780

782781
try
783782
val (eargs, eres) = expected.dealias.stripCapturing match
784-
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
785783
case expected: MethodType => (expected.paramInfos, expected.resType)
786-
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
784+
case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType)
787785
case _ => (aargs.map(_ => WildcardType), WildcardType)
788786
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
789787
val ares1 = adapt(ares, eres, covariant)
@@ -808,7 +806,7 @@ class CheckCaptures extends Recheck, SymTransformer:
808806

809807
try
810808
val eres = expected.dealias.stripCapturing match
811-
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
809+
case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType
812810
case expected: PolyType => expected.resType
813811
case _ => WildcardType
814812

@@ -842,26 +840,26 @@ class CheckCaptures extends Recheck, SymTransformer:
842840

843841
// Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation
844842
val (styp1, leaked) = styp match {
845-
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
843+
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
846844
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
847845
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
848-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
846+
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
849847
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
850848
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
851849
(aargs1, ares1) =>
852850
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
853-
.toFunctionType(isJava = false, alwaysDependent = true))
854-
case actual: MethodType =>
855-
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
856-
(aargs1, ares1) =>
857-
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
858-
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
851+
.toFunctionType(alwaysDependent = true))
852+
case actual @ defn.RefinedFunctionOf(rinfo: PolyType) =>
859853
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
860854
ares1 =>
861855
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
862-
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
856+
val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1)
863857
actual1
864858
)
859+
case actual: MethodType =>
860+
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
861+
(aargs1, ares1) =>
862+
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
865863
case _ =>
866864
(styp, CaptureSet())
867865
}
@@ -1080,7 +1078,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10801078
case CapturingType(parent, refs) =>
10811079
healCaptureSet(refs)
10821080
traverse(parent)
1083-
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
1081+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
10841082
traverse(rinfo)
10851083
case tp: TermLambda =>
10861084
val saved = allowed

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ extends tpd.TreeTraverser:
4040
MethodType.companion(
4141
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
4242
)(argTypes, resType)
43-
.toFunctionType(isJava = false, alwaysDependent = true)
43+
.toFunctionType(alwaysDependent = true)
4444

4545
/** If `tp` is an unboxed capturing type or a function returning an unboxed capturing type,
4646
* convert it to be boxed.
@@ -49,15 +49,15 @@ extends tpd.TreeTraverser:
4949
def recur(tp: Type): Type = tp.dealias match
5050
case tp @ CapturingType(parent, refs) if !tp.isBoxed =>
5151
tp.boxed
52-
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
52+
case tp1 @ AppliedType(tycon, args) if defn.isFunctionNType(tp1) =>
5353
val res = args.last
5454
val boxedRes = recur(res)
5555
if boxedRes eq res then tp
5656
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
57-
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
57+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
5858
val boxedRinfo = recur(rinfo)
5959
if boxedRinfo eq rinfo then tp
60-
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
60+
else boxedRinfo.toFunctionType(alwaysDependent = true)
6161
case tp1: MethodOrPoly =>
6262
val res = tp1.resType
6363
val boxedRes = recur(res)
@@ -129,7 +129,7 @@ extends tpd.TreeTraverser:
129129
apply(parent)
130130
case tp @ AppliedType(tycon, args) =>
131131
val tycon1 = this(tycon)
132-
if defn.isNonRefinedFunction(tp) then
132+
if defn.isFunctionNType(tp) then
133133
// Convert toplevel generic function types to dependent functions
134134
if !defn.isFunctionSymbol(tp.typeSymbol) && (tp.dealias ne tp) then
135135
// This type is a function after dealiasing, so we dealias and recurse.
@@ -149,9 +149,9 @@ extends tpd.TreeTraverser:
149149
tp.derivedAppliedType(tycon1, args1 :+ res1)
150150
else
151151
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
152-
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
152+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
153153
val rinfo1 = apply(rinfo)
154-
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
154+
if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true)
155155
else tp
156156
case tp: MethodType =>
157157
tp.derivedLambdaType(
@@ -197,7 +197,7 @@ extends tpd.TreeTraverser:
197197
val mt = ContextualMethodType(paramName :: Nil)(
198198
_ => paramType :: Nil,
199199
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
200-
val fntpe = defn.PolyFunctionOf(mt)
200+
val fntpe = mt.toFunctionType()
201201
if !encl.isEmpty && isLast then
202202
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
203203
CapturingType(fntpe, cs, boxed = false)

compiler/src/dotty/tools/dotc/cc/Synthetics.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ object Synthetics:
174174
val (et: ExprType) = symd.info: @unchecked
175175
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
176176
def mapFinalResult(tp: Type, f: Type => Type): Type =
177-
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
177+
val defn.FunctionNOf(args, res, isContextual) = tp: @unchecked
178178
if defn.isFunctionNType(res) then
179-
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
179+
defn.FunctionNOf(args, mapFinalResult(res, f), isContextual)
180180
else
181181
f(tp)
182182
val resType1 =

0 commit comments

Comments
 (0)