Skip to content

Commit 7c40fac

Browse files
Handle dependent context functions (#18443)
Add `FunctionTypeOfMethod` extractor that matches any kind of function and return its method type. We use this extractor instead of `ContextFunctionType` to all of * `ContextFunctionN[...]` * `ContextFunctionN[...] { def apply(using ...): R }` where `R` might be dependent on the parameters. * `PolyFunction { def apply(using ...): R }` where `R` might be dependent on the parameters. Currently this one would have at least one erased parameter. The naming of the extractor follows the idea in #18305 (comment).
2 parents 9136582 + d5d8273 commit 7c40fac

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
990990
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
991991
def isStructuralTermSelect(tree: Select) =
992992
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
993-
case defn.PolyFunctionOf(_) =>
993+
case defn.FunctionTypeOfMethod(_) =>
994994
false
995995
case tp: MatchType =>
996996
hasRefinement(tp.tryNormalize)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,24 @@ class Definitions {
11181118
// - .linkedClass: the ClassSymbol of the enumeration (class E)
11191119
sym.owner.linkedClass.typeRef
11201120

1121+
object FunctionTypeOfMethod {
1122+
/** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
1123+
* Extracts the method type type and apply info.
1124+
*/
1125+
def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
1126+
ft match
1127+
case RefinedType(parent, nme.apply, mt: MethodOrPoly)
1128+
if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) =>
1129+
Some(mt)
1130+
case AppliedType(parent, targs) if isFunctionNType(ft) =>
1131+
val isContextual = ft.typeSymbol.name.isContextFunction
1132+
val methodType = if isContextual then ContextualMethodType else MethodType
1133+
Some(methodType(targs.init, targs.last))
1134+
case _ =>
1135+
None
1136+
}
1137+
}
1138+
11211139
object FunctionOf {
11221140
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
11231141
val mt = MethodType.companion(isContextual, false)(args, resultType)

compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ object ContextFunctionResults:
5858
*/
5959
def contextResultsAreErased(sym: Symbol)(using Context): Boolean =
6060
def allErased(tp: Type): Boolean = tp.dealias match
61-
case defn.ContextFunctionType(argTpes, resTpe) =>
62-
argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe)
61+
case ft @ defn.FunctionTypeOfMethod(mt: MethodType) if mt.isContextualMethod =>
62+
mt.nonErasedParamCount == 0 && allErased(mt.resType)
6363
case _ => true
6464
contextResultCount(sym) > 0 && allErased(sym.info.finalResultType)
6565

@@ -68,13 +68,13 @@ object ContextFunctionResults:
6868
*/
6969
def integrateContextResults(tp: Type, crCount: Int)(using Context): Type =
7070
if crCount == 0 then tp
71-
else tp match
71+
else tp.dealias match
7272
case ExprType(rt) =>
7373
integrateContextResults(rt, crCount)
7474
case tp: MethodOrPoly =>
7575
tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount))
76-
case defn.ContextFunctionType(argTypes, resType) =>
77-
MethodType(argTypes, integrateContextResults(resType, crCount - 1))
76+
case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
77+
mt.derivedLambdaType(resType = integrateContextResults(mt.resType, crCount - 1))
7878

7979
/** The total number of parameters of method `sym`, not counting
8080
* erased parameters, but including context result parameters.
@@ -101,7 +101,7 @@ object ContextFunctionResults:
101101
def recur(tp: Type, n: Int): Type =
102102
if n == 0 then tp
103103
else tp match
104-
case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1)
104+
case defn.FunctionTypeOfMethod(mt) => recur(mt.resType, n - 1)
105105
recur(meth.info.finalResultType, depth)
106106

107107
/** Should selection `tree` be eliminated since it refers to an `apply`
@@ -115,8 +115,8 @@ object ContextFunctionResults:
115115
else tree match
116116
case Select(qual, name) =>
117117
if name == nme.apply then
118-
qual.tpe match
119-
case defn.ContextFunctionType(_, _) =>
118+
qual.tpe.nn.dealias match
119+
case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
120120
integrateSelect(qual, n + 1)
121121
case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs
122122
integrateSelect(qual, n + 1)

0 commit comments

Comments
 (0)