Skip to content

Refactor dependent function refinement logic #18305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
def isStructuralTermSelect(tree: Select) =
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
case defn.PolyFunctionOf(_) =>
case defn.FunctionOf(_) =>
false
case RefinedType(parent, rname, rinfo) =>
rname == tree.name || hasRefinement(parent)
Expand Down
9 changes: 3 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1152,13 +1152,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, _) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true)(argTypes, resType),
coord = ctx.owner.coord)
case defn.FunctionOf(mt: MethodType) if mt.isContextualMethod && !mt.isResultDependent => // TODO handle result-dependent functions?
val anonFun = newAnonFun(ctx.owner, mt, coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
expand(target.select(nme.apply).appliedToArgss(refss), mt.resType)(
using ctx.withOwner(anonFun))
Closure(anonFun, lambdaBody)
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ object CaptureSet:
++ (recur(rinfo.resType) // add capture set of result
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
case tpd @ AppliedType(tycon, args) =>
if followResult && defn.isNonRefinedFunction(tpd) then
if followResult && defn.isFunctionNType(tpd) then
recur(args.last)
// must be (pure) FunctionN type since ImpureFunctions have already
// been eliminated in selector's dealias. Use capture set of result.
Expand Down
30 changes: 16 additions & 14 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
capt.println(i"solving $t")
refs.solve()
traverse(parent)
case t @ defn.RefinedFunctionOf(rinfo) =>
case defn.RefinedFunctionOf(rinfo) =>
traverse(rinfo)
case tp: TypeVar =>
case tp: TypeRef =>
Expand Down Expand Up @@ -408,10 +408,10 @@ class CheckCaptures extends Recheck, SymTransformer:
else if meth == defn.Caps_unsafeUnbox then
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing:
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)

mapArgUsing: tp =>
val defn.FunctionOf(mt: MethodType) = tp.dealias: @unchecked
mt.derivedLambdaType(resType = mt.resType.forceBoxStatus(true))
.toFunctionType()
else
super.recheckApply(tree, pt) match
case appType @ CapturingType(appType1, refs) =>
Expand Down Expand Up @@ -502,8 +502,9 @@ class CheckCaptures extends Recheck, SymTransformer:
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
case defn.FunctionOf(mt0: MethodType)
if mt0.paramInfos.nonEmpty && mt0.paramInfos.forall(_.captureSet.isAlwaysEmpty) =>
val ptformals = mt0.paramInfos
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `cap`
// when taking the base type of the actual closures's dependent function
Expand Down Expand Up @@ -707,10 +708,12 @@ class CheckCaptures extends Recheck, SymTransformer:
val eparent1 = recur(eparent)
if eparent1 eq eparent then expected
else CapturingType(eparent1, refs, boxed = expected0.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case defn.FunctionOf(mt: MethodType) =>
actual.dealias match
case defn.FunctionOf(mt2: MethodType) if mt2.isResultDependent =>
mt.toFunctionType(alwaysDependent = true)
case _ =>
expected
case _ =>
expected
recur(expected)
Expand Down Expand Up @@ -781,9 +784,8 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType)
case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
val ares1 = adapt(ares, eres, covariant)
Expand Down Expand Up @@ -842,7 +844,7 @@ class CheckCaptures extends Recheck, SymTransformer:

// Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation
val (styp1, leaked) = styp match {
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ extends tpd.TreeTraverser:
def recur(tp: Type): Type = tp.dealias match
case tp @ CapturingType(parent, refs) if !tp.isBoxed =>
tp.boxed
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
case tp1 @ AppliedType(tycon, args) if defn.isFunctionNType(tp1) =>
val res = args.last
val boxedRes = recur(res)
if boxedRes eq res then tp
Expand Down Expand Up @@ -129,7 +129,7 @@ extends tpd.TreeTraverser:
apply(parent)
case tp @ AppliedType(tycon, args) =>
val tycon1 = this(tycon)
if defn.isNonRefinedFunction(tp) then
if defn.isFunctionNType(tp) then
// Convert toplevel generic function types to dependent functions
if !defn.isFunctionSymbol(tp.typeSymbol) && (tp.dealias ne tp) then
// This type is a function after dealiasing, so we dealias and recurse.
Expand Down Expand Up @@ -197,7 +197,7 @@ extends tpd.TreeTraverser:
val mt = ContextualMethodType(paramName :: Nil)(
_ => paramType :: Nil,
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
val fntpe = defn.PolyFunctionOf(mt)
val fntpe = mt.toFunctionType()
if !encl.isEmpty && isLast then
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
CapturingType(fntpe, cs, boxed = false)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/Synthetics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ object Synthetics:
val (et: ExprType) = symd.info: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
def mapFinalResult(tp: Type, f: Type => Type): Type =
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
val defn.FunctionNOf(args, res, isContextual) = tp: @unchecked
if defn.isFunctionNType(res) then
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
defn.FunctionNOf(args, mapFinalResult(res, f), isContextual)
else
f(tp)
val resType1 =
Expand Down
84 changes: 36 additions & 48 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1109,23 +1109,34 @@ class Definitions {
sym.owner.linkedClass.typeRef

object FunctionOf {
/** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
* Extracts the method type type and apply info.
*/
def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
ft match
case RefinedFunctionOf(mt) => Some(mt)
case FunctionNOf(argTypes, resultType, isContextual) =>
val methodType = if isContextual then ContextualMethodType else MethodType
Some(methodType(argTypes, resultType))
Comment on lines +1119 to +1120
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice from a conceptual viewpoint, but I'm a bit worried about the performance impact of creating all these method types given that they are uncached types. Maybe the name of the extractor could make it clearer that there's a conversion, e.g. if we call it FromMethod (and we call have a corresponding apply to go in the other direction).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followup in #18443

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it's good to make the conversion explicit in the name. The performance aspects are probably OK since the alternative is usually to take the list of AppliedType arguments and split it into inits and last, which is also expensive.

case _ => None
}
}

object FunctionNOf {
/** Create a `FunctionN` or `ContextFunctionN` type applied to the arguments and result type */
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
val mt = MethodType.companion(isContextual, false)(args, resultType)
if mt.hasErasedParams then
RefinedType(PolyFunctionClass.typeRef, nme.apply, mt)
else
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
ft.dealias match
case PolyFunctionOf(mt: MethodType) =>
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
case dft =>
val tsym = dft.typeSymbol
if isFunctionSymbol(tsym) && ft.isRef(tsym) then
val targs = dft.argInfos
if (targs.isEmpty) None
else Some(targs.init, targs.last, tsym.name.isContextFunction)
else None
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)

/** Matches a (possibly aliased) `FunctionN[...]` or `ContextFunctionN[...]`.
* Extracts the list of function argument types, the result type and whether function is contextual.
*/
def unapply(tpe: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
val tsym = tpe.typeSymbol
if isFunctionSymbol(tsym) && tpe.isRef(tsym) then
val targs = tpe.argInfos
if (targs.isEmpty) None
else Some(targs.init, targs.last, tsym.name.isContextFunction)
else None
}
}

Expand Down Expand Up @@ -1165,6 +1176,7 @@ class Definitions {
def isValidMethodType(info: Type) = info match
case info: MethodType =>
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
&& !info.isParamDependent
case _ => false
info match
case info: PolyType => isValidMethodType(info.resType)
Expand Down Expand Up @@ -1731,26 +1743,20 @@ class Definitions {

def isProductSubType(tp: Type)(using Context): Boolean = tp.derivesFrom(ProductClass)

/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ContextFunctionN
* instance?
/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
* - scala.ContextFunctionN
*/
def isNonRefinedFunction(tp: Type)(using Context): Boolean =
val arity = functionArity(tp)
val sym = tp.dealias.typeSymbol
def isFunctionNType(tp: Type)(using Context): Boolean =
val tp1 = tp.dropDependentRefinement
val arity = functionArity(tp1)
val sym = tp1.dealias.typeSymbol

arity >= 0
&& isFunctionClass(sym)
&& tp.isRef(
&& tp1.isRef(
FunctionType(arity, sym.name.isContextFunction).typeSymbol,
skipRefined = false)
end isNonRefinedFunction

/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
* - scala.ContextFunctionN
*/
def isFunctionNType(tp: Type)(using Context): Boolean =
isNonRefinedFunction(tp.dropDependentRefinement)

/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
Expand Down Expand Up @@ -1873,24 +1879,6 @@ class Definitions {
def isContextFunctionType(tp: Type)(using Context): Boolean =
asContextFunctionType(tp).exists

/** An extractor for context function types `As ?=> B`, possibly with
* dependent refinements. Optionally returns a triple consisting of the argument
* types `As`, the result type `B` and a whether the type is an erased context function.
*/
object ContextFunctionType:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it still nice for convenience to keep an extractor just for context functions?

def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] =
if ctx.erasedTypes then
atPhase(erasurePhase)(unapply(tp))
else
asContextFunctionType(tp) match
case PolyFunctionOf(mt: MethodType) =>
Some((mt.paramInfos, mt.resType, mt.erasedParams))
case tp1 if tp1.exists =>
val args = tp1.functionArgInfos
val erasedParams = List.fill(functionArity(tp1)) { false }
Some((args.init, args.last, erasedParams))
case _ => None

/** A whitelist of Scala-2 classes that are known to be pure */
def isAssuredNoInits(sym: Symbol): Boolean =
(sym `eq` SomeClass) || isTupleClass(sym)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ object TypeErasure {
functionType(info.resultType)
case info: MethodType =>
assert(!info.resultType.isInstanceOf[MethodicType])
defn.FunctionType(n = info.erasedParams.count(_ == false))
defn.FunctionType(n = info.nonErasedParamCount)
}
erasure(functionType(applyInfo))
}
Expand Down Expand Up @@ -933,7 +933,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
case tp: TermRef =>
sigName(underlyingOfTermRef(tp))
case ExprType(rt) =>
sigName(defn.FunctionOf(Nil, rt))
sigName(defn.FunctionNOf(Nil, rt))
case tp: TypeVar if !tp.isInstantiated =>
tpnme.Uninstantiated
case tp @ defn.PolyFunctionOf(_) =>
Expand Down
34 changes: 24 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,7 @@ object Types {

/** Dealias, and if result is a dependent function type, drop the `apply` refinement. */
final def dropDependentRefinement(using Context): Type = dealias match {
case RefinedType(parent, nme.apply, mt) if defn.isNonRefinedFunction(parent) => parent
case RefinedType(parent, nme.apply, mt) if defn.isFunctionNType(parent) => parent
case tp => tp
}

Expand Down Expand Up @@ -1890,19 +1890,30 @@ object Types {
case res: MethodType => res.toFunctionType(isJava)
case res => res
}
defn.FunctionOf(
defn.FunctionNOf(
mt.paramInfos.mapConserve(_.translateFromRepeated(toArray = isJava)),
result1, isContextual)
if mt.hasErasedParams then
defn.PolyFunctionOf(mt)
assert(isValidPolyFunctionInfo(mt), s"Not a valid PolyFunction refinement: $mt")
RefinedType(defn.PolyFunctionType, nme.apply, mt)
else if alwaysDependent || mt.isResultDependent then
RefinedType(nonDependentFunType, nme.apply, mt)
else nonDependentFunType
case poly @ PolyType(_, mt: MethodType) =>
assert(!mt.isParamDependent)
defn.PolyFunctionOf(poly)
case poly: PolyType =>
assert(isValidPolyFunctionInfo(poly), s"Not a valid PolyFunction refinement: $poly")
RefinedType(defn.PolyFunctionType, nme.apply, poly)
}

private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidMethodType(info: Type) = info match
case info: MethodType =>
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
&& !info.isParamDependent
case _ => false
info match
case info: PolyType => isValidMethodType(info.resType)
case _ => isValidMethodType(info)

/** The signature of this type. This is by default NotAMethod,
* but is overridden for PolyTypes, MethodTypes, and TermRef types.
* (the reason why we deviate from the "final-method-with-pattern-match-in-base-class"
Expand Down Expand Up @@ -3724,8 +3735,6 @@ object Types {

def companion: LambdaTypeCompanion[ThisName, PInfo, This]

def erasedParams(using Context) = List.fill(paramInfos.size)(false)

/** The type `[tparams := paramRefs] tp`, where `tparams` can be
* either a list of type parameter symbols or a list of lambda parameters
*
Expand Down Expand Up @@ -4017,13 +4026,18 @@ object Types {
final override def isImplicitMethod: Boolean =
companion.eq(ImplicitMethodType) || isContextualMethod
final override def hasErasedParams(using Context): Boolean =
erasedParams.contains(true)
paramInfos.exists(p => p.hasAnnotation(defn.ErasedParamAnnot))

final override def isContextualMethod: Boolean =
companion.eq(ContextualMethodType)

override def erasedParams(using Context): List[Boolean] =
def erasedParams(using Context): List[Boolean] =
paramInfos.map(p => p.hasAnnotation(defn.ErasedParamAnnot))

def nonErasedParamCount(using Context): Int =
paramInfos.count(p => !p.hasAnnotation(defn.ErasedParamAnnot))


protected def prefixString: String = companion.prefixString
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
"(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")"

protected def paramsText(lam: LambdaType): Text = {
val erasedParams = lam.erasedParams
def paramText(ref: ParamRef, erased: Boolean) =
def paramText(ref: ParamRef) =
val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot)
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true)
Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ")
Text(lam.paramRefs.map(paramText), ", ")
}

protected def ParamRefNameString(name: Name): String = nameString(name)
Expand Down
Loading