Skip to content

Backport "Add defn.RefinedFunctionOf extractor" to LTS #20662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class CheckCaptures extends Recheck, SymTransformer:
capt.println(i"solving $t")
refs.solve()
traverse(parent)
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
case t @ defn.RefinedFunctionOf(rinfo) =>
traverse(rinfo)
case tp: TypeVar =>
case tp: TypeRef =>
Expand Down Expand Up @@ -769,7 +769,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
Expand All @@ -779,11 +779,11 @@ class CheckCaptures extends Recheck, SymTransformer:
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
case actual @ defn.RefinedFunctionOf(rinfo: PolyType) =>
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
ares1 =>
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1)
actual1
)
case _ =>
Expand Down Expand Up @@ -996,7 +996,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case defn.RefinedFunctionOf(rinfo: MethodType) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
case tp1 @ defn.RefinedFunctionOf(rinfo: MethodType) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(alwaysDependent = true)
Expand Down Expand Up @@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case defn.RefinedFunctionOf(rinfo: MethodType) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true)
else tp
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,20 @@ class Definitions {
case _ => None
}

object RefinedFunctionOf {
/** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
* Extracts the method type type and apply info.
*/
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = {
tpe.refinedInfo match
case mt: MethodOrPoly
if tpe.refinedName == nme.apply
&& (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) =>
Some(mt)
case _ => None
}
}

object PolyFunctionOf {
/** Matches a refined `PolyFunction` type and extracts the apply info.
*
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4055,8 +4055,8 @@ object Types {
tp.derivedAppliedType(tycon, addInto(args.head) :: Nil)
case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) =>
wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last)))
case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) =>
wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo)))
case tp @ defn.RefinedFunctionOf(rinfo) =>
wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo)))
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = addInto(tp.resType))
case ExprType(resType) =>
Expand Down
Loading