Skip to content

Commit 9eb04f0

Browse files
committed
Add defn.RefinedFunctionOf extractor
1 parent 9b9d8dd commit 9eb04f0

File tree

5 files changed

+33
-22
lines changed

5 files changed

+33
-22
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,9 +876,9 @@ object CaptureSet:
876876
empty
877877
case CapturingType(parent, refs) =>
878878
recur(parent) ++ refs
879-
case tpd @ RefinedType(parent, _, rinfo: MethodType)
880-
if followResult && defn.isFunctionNType(tpd) =>
881-
ofType(parent, followResult = false) // pick up capture set from parent type
879+
case tpd @ defn.RefinedFunctionOf(rinfo: MethodType)
880+
if followResult =>
881+
ofType(tpd.parent, followResult = false) // pick up capture set from parent type
882882
++ (recur(rinfo.resType) // add capture set of result
883883
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
884884
case tpd @ AppliedType(tycon, args) =>

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
195195
capt.println(i"solving $t")
196196
refs.solve()
197197
traverse(parent)
198-
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
198+
case defn.RefinedFunctionOf(rinfo) =>
199199
traverse(rinfo)
200200
case tp: TypeVar =>
201201
case tp: TypeRef =>
@@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer:
302302
t
303303
case _ =>
304304
val t1 = t match
305-
case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) =>
306-
t.derivedRefinedType(parent, rname, this(rinfo))
305+
case t @ defn.RefinedFunctionOf(rinfo: MethodType) =>
306+
t.derivedRefinedType(t.parent, t.refinedName, this(rinfo))
307307
case _ =>
308308
mapOver(t)
309309
if variance > 0 then t1
@@ -782,7 +782,6 @@ class CheckCaptures extends Recheck, SymTransformer:
782782
val (eargs, eres) = expected.dealias.stripCapturing match
783783
case expected: MethodType => (expected.paramInfos, expected.resType)
784784
case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType)
785-
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
786785
case _ => (aargs.map(_ => WildcardType), WildcardType)
787786
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
788787
val ares1 = adapt(ares, eres, covariant)
@@ -844,23 +843,23 @@ class CheckCaptures extends Recheck, SymTransformer:
844843
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
845844
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
846845
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
847-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
846+
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
848847
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
849848
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
850849
(aargs1, ares1) =>
851850
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
852851
.toFunctionType(alwaysDependent = true))
853-
case actual: MethodType =>
854-
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
855-
(aargs1, ares1) =>
856-
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
857-
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
852+
case actual @ defn.RefinedFunctionOf(rinfo: PolyType) =>
858853
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
859854
ares1 =>
860855
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
861-
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
856+
val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1)
862857
actual1
863858
)
859+
case actual: MethodType =>
860+
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
861+
(aargs1, ares1) =>
862+
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
864863
case _ =>
865864
(styp, CaptureSet())
866865
}
@@ -1079,7 +1078,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10791078
case CapturingType(parent, refs) =>
10801079
healCaptureSet(refs)
10811080
traverse(parent)
1082-
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
1081+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
10831082
traverse(rinfo)
10841083
case tp: TermLambda =>
10851084
val saved = allowed

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
5454
val boxedRes = recur(res)
5555
if boxedRes eq res then tp
5656
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
57-
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
57+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
5858
val boxedRinfo = recur(rinfo)
5959
if boxedRinfo eq rinfo then tp
6060
else boxedRinfo.toFunctionType(alwaysDependent = true)
@@ -149,7 +149,7 @@ extends tpd.TreeTraverser:
149149
tp.derivedAppliedType(tycon1, args1 :+ res1)
150150
else
151151
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
152-
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
152+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
153153
val rinfo1 = apply(rinfo)
154154
if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true)
155155
else tp

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,16 +1114,28 @@ class Definitions {
11141114
*/
11151115
def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
11161116
ft match
1117-
case RefinedType(parent, nme.apply, mt: MethodOrPoly)
1118-
if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
1119-
Some(mt)
1117+
case RefinedFunctionOf(mt) => Some(mt)
11201118
case FunctionNOf(argTypes, resultType, isContextual) =>
11211119
val methodType = if isContextual then ContextualMethodType else MethodType
11221120
Some(methodType(argTypes, resultType))
11231121
case _ => None
11241122
}
11251123
}
11261124

1125+
object RefinedFunctionOf {
1126+
/** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
1127+
* Extracts the method type type and apply info.
1128+
*/
1129+
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = {
1130+
tpe.refinedInfo match
1131+
case mt: MethodOrPoly
1132+
if tpe.refinedName == nme.apply
1133+
&& (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) =>
1134+
Some(mt)
1135+
case _ => None
1136+
}
1137+
}
1138+
11271139
object FunctionNOf {
11281140
/** Create a `FunctionN` or `ContextFunctionN` type applied to the arguments and result type */
11291141
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4106,8 +4106,8 @@ object Types {
41064106
tp.derivedAppliedType(tycon, addInto(args.head) :: Nil)
41074107
case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) =>
41084108
wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last)))
4109-
case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) =>
4110-
wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo)))
4109+
case tp @ defn.RefinedFunctionOf(rinfo) =>
4110+
wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo)))
41114111
case tp: MethodOrPoly =>
41124112
tp.derivedLambdaType(resType = addInto(tp.resType))
41134113
case ExprType(resType) =>

0 commit comments

Comments
 (0)