Skip to content

Commit 7061e48

Browse files
committed
Add DependentFunctionRefinementOf
1 parent c13c1c9 commit 7061e48

File tree

4 files changed

+26
-10
lines changed

4 files changed

+26
-10
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,10 +707,13 @@ class CheckCaptures extends Recheck, SymTransformer:
707707
val eparent1 = recur(eparent)
708708
if eparent1 eq eparent then expected
709709
else CapturingType(eparent1, refs, boxed = expected0.isBoxed)
710-
case expected @ defn.FunctionOf(args, resultType, isContextual)
711-
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
712-
val expected1 = toDepFun(args, resultType, isContextual)
713-
expected1
710+
case defn.DependentFunctionRefinementOf(_, _) =>
711+
expected
712+
case expected @ defn.FunctionOf(args, resultType, isContextual) =>
713+
actual match
714+
case defn.DependentFunctionRefinementOf(_, _) => expected
715+
case _ if defn.isFunctionNType(actual) => toDepFun(args, resultType, isContextual)
716+
case _ => expected
714717
case _ =>
715718
expected
716719
recur(expected)
@@ -842,7 +845,7 @@ class CheckCaptures extends Recheck, SymTransformer:
842845

843846
// Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation
844847
val (styp1, leaked) = styp match {
845-
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
848+
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
846849
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
847850
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
848851
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,17 @@ class Definitions {
11741174
case _ => isValidMethodType(info)
11751175
}
11761176

1177+
object DependentFunctionRefinementOf {
1178+
/** Matches a refined function FT type and extracts FT and apply info.
1179+
*
1180+
* Pattern: `$ft { def apply: $mt }`
1181+
*/
1182+
def unapply(ft: Type)(using Context): Option[(Type, MethodType)] = ft.dealias match
1183+
case RefinedType(parent, nme.apply, mt: MethodType) if isNonRefinedFunction(parent) =>
1184+
Some((parent, mt))
1185+
case _ => None
1186+
}
1187+
11771188
object PartialFunctionOf {
11781189
def apply(arg: Type, result: Type)(using Context): Type =
11791190
PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil)

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,9 +1893,11 @@ class Namer { typer: Typer =>
18931893
val originalTp = defaultParamType
18941894
val approxTp = wildApprox(originalTp)
18951895
approxTp.stripPoly match
1896+
case defn.DependentFunctionRefinementOf(_, mt) if mt.isContextualMethod =>
1897+
// in this case `resType` is lying, gives us only the non-dependent upper bound
1898+
originalTp
18961899
case atp @ defn.ContextFunctionOf(_, resType)
1897-
if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound
1898-
|| resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) =>
1900+
if resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) =>
18991901
originalTp
19001902
case _ =>
19011903
approxTp

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,9 +1799,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
17991799
case PolyType(_, _, mt1) => mt1.hasErasedParams
18001800
case _ => false
18011801
def isDependentFunctionType: Boolean =
1802-
val tpNoRefinement = self.dropDependentRefinement
1803-
tpNoRefinement != self
1804-
&& dotc.core.Symbols.defn.isNonRefinedFunction(tpNoRefinement)
1802+
self match
1803+
case dotc.core.Symbols.defn.DependentFunctionRefinementOf(_, _) => true
1804+
case _ => false
18051805
def isTupleN: Boolean =
18061806
dotc.core.Symbols.defn.isTupleNType(self)
18071807
def select(sym: Symbol): TypeRepr = self.select(sym)

0 commit comments

Comments
 (0)