From 757a4bee8d493600ee71601c5ea6ae7b1f3a59f3 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 21 Aug 2023 11:01:10 +0200 Subject: [PATCH] Add `defn.RefinedFunctionOf` extractor --- compiler/src/dotty/tools/dotc/cc/CaptureSet.scala | 5 ++--- .../src/dotty/tools/dotc/cc/CheckCaptures.scala | 14 +++++++------- compiler/src/dotty/tools/dotc/cc/Setup.scala | 4 ++-- .../src/dotty/tools/dotc/core/Definitions.scala | 14 ++++++++++++++ compiler/src/dotty/tools/dotc/core/Types.scala | 4 ++-- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 3f2beaa3ff55..84a04c13a91f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -876,9 +876,8 @@ object CaptureSet: empty case CapturingType(parent, refs) => recur(parent) ++ refs - case tpd @ RefinedType(parent, _, rinfo: MethodType) - if followResult && defn.isFunctionNType(tpd) => - ofType(parent, followResult = false) // pick up capture set from parent type + case tpd @ defn.RefinedFunctionOf(rinfo: MethodType) if followResult => + ofType(tpd.parent, followResult = false) // pick up capture set from parent type ++ (recur(rinfo.resType) // add capture set of result -- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters case tpd @ AppliedType(tycon, args) => diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index b6b5d569677c..22e1dbe265cc 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer: capt.println(i"solving $t") refs.solve() traverse(parent) - case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) => + case t @ defn.RefinedFunctionOf(rinfo) => traverse(rinfo) case tp: TypeVar => case tp: TypeRef => @@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer: t case _ => val t1 = t match - case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) => - t.derivedRefinedType(parent, rname, this(rinfo)) + case t @ defn.RefinedFunctionOf(rinfo: MethodType) => + t.derivedRefinedType(t.parent, t.refinedName, this(rinfo)) case _ => mapOver(t) if variance > 0 then t1 @@ -845,7 +845,7 @@ class CheckCaptures extends Recheck, SymTransformer: case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) => adaptFun(actual, args.init, args.last, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1)) - case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) => + case actual @ defn.RefinedFunctionOf(rinfo: MethodType) => // TODO Find a way to combine handling of generic and dependent function types (here and elsewhere) adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox, (aargs1, ares1) => @@ -855,11 +855,11 @@ class CheckCaptures extends Recheck, SymTransformer: adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedLambdaType(paramInfos = aargs1, resType = ares1)) - case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) => + case actual @ defn.RefinedFunctionOf(rinfo: PolyType) => adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox, ares1 => val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1) - val actual1 = actual.derivedRefinedType(p, nme, rinfo1) + val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1) actual1 ) case _ => @@ -1080,7 +1080,7 @@ class CheckCaptures extends Recheck, SymTransformer: case CapturingType(parent, refs) => healCaptureSet(refs) traverse(parent) - case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => traverse(rinfo) case tp: TermLambda => val saved = allowed diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 4c32c2908635..2d00bc7afaa6 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -54,7 +54,7 @@ extends tpd.TreeTraverser: val boxedRes = recur(res) if boxedRes eq res then tp else tp1.derivedAppliedType(tycon, args.init :+ boxedRes) - case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) => + case tp1 @ defn.RefinedFunctionOf(rinfo: MethodType) => val boxedRinfo = recur(rinfo) if boxedRinfo eq rinfo then tp else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true) @@ -149,7 +149,7 @@ extends tpd.TreeTraverser: tp.derivedAppliedType(tycon1, args1 :+ res1) else tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg))) - case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => val rinfo1 = apply(rinfo) if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) else tp diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ea48dd2b56fa..a52b67e88f5d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1129,6 +1129,20 @@ class Definitions { } } + object RefinedFunctionOf { + /** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = { + tpe.refinedInfo match + case mt: MethodOrPoly + if tpe.refinedName == nme.apply + && (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) => + Some(mt) + case _ => None + } + } + object PolyFunctionOf { /** Creates a refined `PolyFunction` with an `apply` method with the given info. */ diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 576f9a6c64f6..d24aeab125c3 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4097,8 +4097,8 @@ object Types { tp.derivedAppliedType(tycon, addInto(args.head) :: Nil) case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) => wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last))) - case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) => - wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo))) + case tp @ defn.RefinedFunctionOf(rinfo) => + wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo))) case tp: MethodOrPoly => tp.derivedLambdaType(resType = addInto(tp.resType)) case ExprType(resType) =>