Skip to content

Commit 019d203

Browse files
Inferring tracked (#21628)
Infer `tracked` for parameters that are referenced in the public signatures of the defining class. e.g. ```scala 3 class OrdSet(val ord: Ordering) { type Set = List[ord.T] def empty: Set = Nil implicit class helper(s: Set) { def add(x: ord.T): Set = x :: remove(x) def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0) def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0) } } ``` In the example above, `ord` is referenced in the signatures of the public members of `OrdSet`, so a `tracked` modifier will be inserted automatically. Aldo generalize the condition for infering tracked for context bounds. Explicit `using val` witnesses will now also be `tracked` by default. This implementation should be safe with regards to not introducing spurious cyclic reference errors. Current limitations (I'll create separate issues for them, once this is merged): - Inferring `tracked` for given classes is done after the desugaring to class + def, so the def doesn't know about `tracked` being set on the original constructor parameter. This might be worked around by watching the original symbol or adding an attachment pointer to the implicit wrapper. ```scala 3 given mInst: (c: C) => M: def foo: c.T = c.foo ``` - Passing parameters as an **inferred** `tracked` arguments in parents doesn't work, since forcing a parent (term) isn't safe. This can be replaced with a lint that is checked after Namer.
1 parent 312c89a commit 019d203

9 files changed

+347
-34
lines changed

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+3
Original file line numberDiff line numberDiff line change
@@ -2754,6 +2754,9 @@ object SymDenotations {
27542754
/** Sets all missing fields of given denotation */
27552755
def complete(denot: SymDenotation)(using Context): Unit
27562756

2757+
/** Is this a completer for an explicit type tree */
2758+
def isExplicit: Boolean = false
2759+
27572760
def apply(sym: Symbol): LazyType = this
27582761
def apply(module: TermSymbol, modcls: ClassSymbol): LazyType = this
27592762

compiler/src/dotty/tools/dotc/typer/Namer.scala

+101-33
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ class Namer { typer: Typer =>
278278
if rhs.isEmpty || flags.is(Opaque) then flags |= Deferred
279279
if flags.is(Param) then tree.rhs else analyzeRHS(tree.rhs)
280280

281+
def hasExplicitType(tree: ValOrDefDef): Boolean =
282+
!tree.tpt.isEmpty || tree.mods.isOneOf(TermParamOrAccessor)
283+
281284
// to complete a constructor, move one context further out -- this
282285
// is the context enclosing the class. Note that the context in which a
283286
// constructor is recorded and the context in which it is completed are
@@ -291,6 +294,8 @@ class Namer { typer: Typer =>
291294

292295
val completer = tree match
293296
case tree: TypeDef => TypeDefCompleter(tree)(cctx)
297+
case tree: ValOrDefDef if Feature.enabled(Feature.modularity) && hasExplicitType(tree) =>
298+
new Completer(tree, isExplicit = true)(cctx)
294299
case _ => Completer(tree)(cctx)
295300
val info = adjustIfModule(completer, tree)
296301
createOrRefine[Symbol](tree, name, flags, ctx.owner, _ => info,
@@ -800,7 +805,7 @@ class Namer { typer: Typer =>
800805
}
801806

802807
/** The completer of a symbol defined by a member def or import (except ClassSymbols) */
803-
class Completer(val original: Tree)(ictx: Context) extends LazyType with SymbolLoaders.SecondCompleter {
808+
class Completer(val original: Tree, override val isExplicit: Boolean = false)(ictx: Context) extends LazyType with SymbolLoaders.SecondCompleter {
804809

805810
protected def localContext(owner: Symbol): FreshContext = ctx.fresh.setOwner(owner).setTree(original)
806811

@@ -1783,7 +1788,7 @@ class Namer { typer: Typer =>
17831788
sym.owner.typeParams.foreach(_.ensureCompleted())
17841789
completeTrailingParamss(constr, sym, indexingCtor = true)
17851790
if Feature.enabled(modularity) then
1786-
constr.termParamss.foreach(_.foreach(setTracked))
1791+
constr.termParamss.foreach(_.foreach(setTrackedConstrParam))
17871792

17881793
/** The signature of a module valdef.
17891794
* This will compute the corresponding module class TypeRef immediately
@@ -1923,22 +1928,24 @@ class Namer { typer: Typer =>
19231928
def wrapRefinedMethType(restpe: Type): Type =
19241929
wrapMethType(addParamRefinements(restpe, paramSymss))
19251930

1931+
def addTrackedIfNeeded(ddef: DefDef, owningSym: Symbol): Unit =
1932+
for params <- ddef.termParamss; param <- params do
1933+
val psym = symbolOfTree(param)
1934+
if needsTracked(psym, param, owningSym) then
1935+
psym.setFlag(Tracked)
1936+
setParamTrackedWithAccessors(psym, sym.maybeOwner.infoOrCompleter)
1937+
1938+
if Feature.enabled(modularity) then addTrackedIfNeeded(ddef, sym.maybeOwner)
1939+
19261940
if isConstructor then
19271941
// set result type tree to unit, but take the current class as result type of the symbol
19281942
typedAheadType(ddef.tpt, defn.UnitType)
19291943
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
19301944
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
19311945
mt
1932-
else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then
1933-
// set every context bound evidence parameter of a given companion method
1934-
// to be tracked, provided it has a type that has an abstract type member.
1935-
// Add refinements for all tracked parameters to the result type.
1936-
for params <- ddef.termParamss; param <- params do
1937-
val psym = symbolOfTree(param)
1938-
if needsTracked(psym, param) then psym.setFlag(Tracked)
1939-
valOrDefDefSig(ddef, sym, paramSymss, wrapRefinedMethType)
19401946
else
1941-
valOrDefDefSig(ddef, sym, paramSymss, wrapMethType)
1947+
val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType
1948+
valOrDefDefSig(ddef, sym, paramSymss, paramFn)
19421949
end defDefSig
19431950

19441951
/** Complete the trailing parameters of a DefDef,
@@ -1987,36 +1994,97 @@ class Namer { typer: Typer =>
19871994
cls.srcPos)
19881995
case _ =>
19891996

1990-
/** Under x.modularity, we add `tracked` to context bound witnesses
1991-
* that have abstract type members
1997+
private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit =
1998+
for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do
1999+
acc.resetFlag(PrivateLocal)
2000+
psym.setFlag(Tracked)
2001+
acc.setFlag(Tracked)
2002+
2003+
/** `psym` needs tracked if it is referenced in any of the public signatures
2004+
* of the defining class or when `psym` is a context bound witness with an
2005+
* abstract type member
19922006
*/
1993-
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
1994-
!sym.is(Tracked)
1995-
&& param.hasAttachment(ContextBoundParam)
1996-
&& sym.info.memberNames(abstractTypeNameFilter).nonEmpty
1997-
1998-
/** Under x.modularity, set every context bound evidence parameter of a class to be tracked,
1999-
* provided it has a type that has an abstract type member. Reset private and local flags
2000-
* so that the parameter becomes a `val`.
2007+
def needsTracked(psym: Symbol, param: ValDef, owningSym: Symbol)(using Context) =
2008+
lazy val abstractContextBound = isContextBoundWitnessWithAbstractMembers(psym, param, owningSym)
2009+
lazy val isRefInSignatures =
2010+
psym.maybeOwner.isPrimaryConstructor
2011+
&& isReferencedInPublicSignatures(psym)
2012+
!psym.is(Tracked)
2013+
&& psym.isTerm
2014+
&& (
2015+
abstractContextBound
2016+
|| isRefInSignatures
2017+
)
2018+
2019+
/** Under x.modularity, we add `tracked` to context bound witnesses and
2020+
* explicit evidence parameters that have abstract type members
2021+
*/
2022+
private def isContextBoundWitnessWithAbstractMembers(psym: Symbol, param: ValDef, owningSym: Symbol)(using Context): Boolean =
2023+
val accessorSyms = maybeParamAccessors(owningSym, psym)
2024+
(owningSym.isClass || owningSym.isAllOf(Given | Method))
2025+
&& (param.hasAttachment(ContextBoundParam) || (psym.isOneOf(GivenOrImplicit) && !accessorSyms.forall(_.isOneOf(PrivateLocal))))
2026+
&& psym.info.memberNames(abstractTypeNameFilter).nonEmpty
2027+
2028+
extension (sym: Symbol)
2029+
private def infoWithForceNonInferingCompleter(using Context): Type = sym.infoOrCompleter match
2030+
case tpe: LazyType if tpe.isExplicit => sym.info
2031+
case tpe if sym.isType => sym.info
2032+
case info => info
2033+
2034+
/** Under x.modularity, we add `tracked` to term parameters whose types are
2035+
* referenced in public signatures of the defining class
2036+
*/
2037+
private def isReferencedInPublicSignatures(sym: Symbol)(using Context): Boolean =
2038+
val owner = sym.maybeOwner.maybeOwner
2039+
val accessorSyms = maybeParamAccessors(owner, sym)
2040+
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
2041+
owner.infoOrCompleter match
2042+
case info: ClassInfo =>
2043+
info.decls.filter(_.isPublic)
2044+
.filter(_ != sym.maybeOwner)
2045+
.exists { decl =>
2046+
tpeContainsSymbolRef(decl.infoWithForceNonInferingCompleter, accessorSyms)
2047+
}
2048+
case _ => false
2049+
checkOwnerMemberSignatures(owner)
2050+
2051+
/** Check if any of syms are referenced in tpe */
2052+
private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean =
2053+
val acc = new ExistsAccumulator(
2054+
{ tpe => tpe.termSymbol.exists && syms.contains(tpe.termSymbol) },
2055+
StopAt.Static,
2056+
forceLazy = false
2057+
) {
2058+
override def apply(acc: Boolean, tpe: Type): Boolean = super.apply(acc, tpe.safeDealias)
2059+
}
2060+
acc(false, tpe)
2061+
2062+
private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] = owner.infoOrCompleter match
2063+
case info: ClassInfo =>
2064+
info.decls.lookupAll(sym.name).filter(d => d.is(ParamAccessor)).toList
2065+
case _ => List(sym)
2066+
2067+
/** Under x.modularity, set every context bound evidence parameter or public
2068+
* using parameter of a class to be tracked, provided it has a type that has
2069+
* an abstract type member. Reset private and local flags so that the
2070+
* parameter becomes a `val`.
20012071
*/
2002-
def setTracked(param: ValDef)(using Context): Unit =
2072+
def setTrackedConstrParam(param: ValDef)(using Context): Unit =
20032073
val sym = symbolOfTree(param)
20042074
sym.maybeOwner.maybeOwner.infoOrCompleter match
2005-
case info: ClassInfo if needsTracked(sym, param) =>
2075+
case info: ClassInfo
2076+
if !sym.is(Tracked) && isContextBoundWitnessWithAbstractMembers(sym, param, sym.maybeOwner.maybeOwner) =>
20062077
typr.println(i"set tracked $param, $sym: ${sym.info} containing ${sym.info.memberNames(abstractTypeNameFilter).toList}")
2007-
for acc <- info.decls.lookupAll(sym.name) if acc.is(ParamAccessor) do
2008-
acc.resetFlag(PrivateLocal)
2009-
acc.setFlag(Tracked)
2010-
sym.setFlag(Tracked)
2078+
setParamTrackedWithAccessors(sym, info)
20112079
case _ =>
20122080

20132081
def inferredResultType(
2014-
mdef: ValOrDefDef,
2015-
sym: Symbol,
2016-
paramss: List[List[Symbol]],
2017-
paramFn: Type => Type,
2018-
fallbackProto: Type
2019-
)(using Context): Type =
2082+
mdef: ValOrDefDef,
2083+
sym: Symbol,
2084+
paramss: List[List[Symbol]],
2085+
paramFn: Type => Type,
2086+
fallbackProto: Type
2087+
)(using Context): Type =
20202088
/** Is this member tracked? This is true if it is marked as `tracked` or if
20212089
* it overrides a `tracked` member. To account for the later, `isTracked`
20222090
* is overriden to `true` as a side-effect of computing `inherited`.

docs/_docs/reference/experimental/modularity.md

+40-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,46 @@ This works as it should now. Without the addition of `tracked` to the
108108
parameter of `SetFunctor` typechecking would immediately lose track of
109109
the element type `T` after an `add`, and would therefore fail.
110110

111-
**Discussion**
111+
**Syntax Change**
112+
113+
```
114+
ClsParam ::= {Annotation} [{Modifier | ‘tracked’} (‘val’ | ‘var’)] Param
115+
```
116+
117+
The (soft) `tracked` modifier is only allowed for `val` parameters of classes.
118+
119+
### Tracked inference
120+
121+
In some cases `tracked` can be infered and doesn't have to be written
122+
explicitly. A common such case is when a class parameter is referenced in the
123+
signatures of the public members of the class. e.g.
124+
```scala 3
125+
class OrdSet(val ord: Ordering) {
126+
type Set = List[ord.T]
127+
def empty: Set = Nil
128+
129+
implicit class helper(s: Set) {
130+
def add(x: ord.T): Set = x :: remove(x)
131+
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
132+
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
133+
}
134+
}
135+
```
136+
In the example above, `ord` is referenced in the signatures of the public
137+
members of `OrdSet`, so a `tracked` modifier will be inserted automatically.
138+
139+
Another common case is when a context bound has an associated type (i.e. an abstract type member) e.g.
140+
```scala 3
141+
trait TC:
142+
type Self
143+
type T
144+
145+
class Klass[A: {TC as tc}]
146+
```
147+
148+
Here, `tc` is a context bound with an associated type `T`, so `tracked` will be inferred for `tc`.
149+
150+
### Discussion
112151

113152
Since `tracked` is so useful, why not assume it by default? First, `tracked` makes sense only for `val` parameters. If a class parameter is not also a field declared using `val` then there's nothing to refine in the constructor result type. One could think of at least making all `val` parameters tracked by default, but that would be a backwards incompatible change. For instance, the following code would break:
114153

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.language.experimental.modularity
2+
3+
trait T:
4+
type Self
5+
type X
6+
def foo: Self
7+
8+
class D[C](using wd: C is T)
9+
class E(using we: Int is T)
10+
11+
def Test =
12+
given w: Int is T:
13+
def foo: Int = 42
14+
type X = Long
15+
val d = D(using w)
16+
summon[d.wd.X =:= Long] // error
17+
val e = E(using w)
18+
summon[e.we.X =:= Long] // error

tests/pos/infer-tracked-1.scala

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
trait Ordering {
5+
type T
6+
def compare(t1:T, t2: T): Int
7+
}
8+
9+
class SetFunctor(val ord: Ordering) {
10+
type Set = List[ord.T]
11+
def empty: Set = Nil
12+
13+
implicit class helper(s: Set) {
14+
def add(x: ord.T): Set = x :: remove(x)
15+
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
16+
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
17+
}
18+
}
19+
20+
object Test {
21+
val orderInt = new Ordering {
22+
type T = Int
23+
def compare(t1: T, t2: T): Int = t1 - t2
24+
}
25+
26+
val IntSet = new SetFunctor(orderInt)
27+
import IntSet.*
28+
29+
def main(args: Array[String]) = {
30+
val set = IntSet.empty.add(6).add(8).add(23)
31+
assert(!set.member(7))
32+
assert(set.member(8))
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.language.experimental.modularity
2+
3+
trait T:
4+
type Self
5+
type X
6+
def foo: Self
7+
8+
class D[C](using val wd: C is T)
9+
class E(using val we: Int is T)
10+
11+
def Test =
12+
given w: Int is T:
13+
def foo: Int = 42
14+
type X = Long
15+
val d = D(using w)
16+
summon[d.wd.X =:= Long]
17+
val e = E(using w)
18+
summon[e.we.X =:= Long]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
trait WithValue { type Value = Int }
5+
6+
case class Year(value: Int) extends WithValue {
7+
val x: Value = 2
8+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
import collection.mutable
5+
6+
/// A parser combinator.
7+
trait Combinator[T]:
8+
9+
/// The context from which elements are being parsed, typically a stream of tokens.
10+
type Context
11+
/// The element being parsed.
12+
type Element
13+
14+
extension (self: T)
15+
/// Parses and returns an element from `context`.
16+
def parse(context: Context): Option[Element]
17+
end Combinator
18+
19+
final case class Apply[C, E](action: C => Option[E])
20+
final case class Combine[A, B](first: A, second: B)
21+
22+
object test:
23+
24+
class apply[C, E] extends Combinator[Apply[C, E]]:
25+
type Context = C
26+
type Element = E
27+
extension(self: Apply[C, E])
28+
def parse(context: C): Option[E] = self.action(context)
29+
30+
def apply[C, E]: apply[C, E] = new apply[C, E]
31+
32+
class combine[A, B](
33+
val f: Combinator[A],
34+
val s: Combinator[B] { type Context = f.Context}
35+
) extends Combinator[Combine[A, B]]:
36+
type Context = f.Context
37+
type Element = (f.Element, s.Element)
38+
extension(self: Combine[A, B])
39+
def parse(context: Context): Option[Element] = ???
40+
41+
def combine[A, B](
42+
_f: Combinator[A],
43+
_s: Combinator[B] { type Context = _f.Context}
44+
) = new combine[A, B](_f, _s)
45+
// cast is needed since the type of new combine[A, B](_f, _s)
46+
// drops the required refinement.
47+
48+
extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
49+
if buf.isEmpty then None
50+
else try Some(buf.head) finally buf.remove(0)
51+
52+
@main def hello: Unit = {
53+
val source = (0 to 10).toList
54+
val stream = source.to(mutable.ListBuffer)
55+
56+
val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
57+
val m = Combine(n, n)
58+
59+
val c = combine(
60+
apply[mutable.ListBuffer[Int], Int],
61+
apply[mutable.ListBuffer[Int], Int]
62+
)
63+
val r = c.parse(m)(stream) // was type mismatch, now OK
64+
val rc: Option[(Int, Int)] = r
65+
}

0 commit comments

Comments
 (0)