Skip to content

Commit 166bd97

Browse files
committed
Check for parameter references in type bounds when infering tracked
1 parent 41a5126 commit 166bd97

6 files changed

+281
-10
lines changed

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import config.Feature.{sourceVersion, modularity}
3030
import config.SourceVersion.*
3131

3232
import scala.compiletime.uninitialized
33+
import dotty.tools.dotc.transform.init.Util.tree
3334

3435
/** This class creates symbols from definitions and imports and gives them
3536
* lazy types.
@@ -1648,7 +1649,6 @@ class Namer { typer: Typer =>
16481649
* as an attachment on the ClassDef tree.
16491650
*/
16501651
def enterParentRefinementSyms(refinements: List[(Name, Type)]) =
1651-
println(s"For class $cls, entering parent refinements: $refinements")
16521652
val refinedSyms = mutable.ListBuffer[Symbol]()
16531653
for (name, tp) <- refinements do
16541654
if decls.lookupEntry(name) == null then
@@ -1658,7 +1658,6 @@ class Namer { typer: Typer =>
16581658
case _ => Synthetic | Deferred
16591659
val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
16601660
refinedSyms += s
1661-
println(s" entered $s")
16621661
if refinedSyms.nonEmpty then
16631662
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
16641663
original.pushAttachment(ParentRefinements, refinedSyms.toList)
@@ -1996,10 +1995,11 @@ class Namer { typer: Typer =>
19961995
*/
19971996
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
19981997
!sym.is(Tracked)
1998+
&& sym.maybeOwner.isConstructor
19991999
&& (
20002000
isContextBoundWitnessWithAbstractMembers(sym, param)
20012001
|| isReferencedInPublicSignatures(sym)
2002-
// || isPassedToTrackedParentParameter(sym, param)
2002+
|| isPassedToTrackedParentParameter(sym, param)
20032003
)
20042004

20052005
/** Under x.modularity, we add `tracked` to context bound witnesses
@@ -2018,11 +2018,14 @@ class Namer { typer: Typer =>
20182018
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
20192019
owner.infoOrCompleter match
20202020
case info: ClassInfo =>
2021-
info.decls.filter(d => !d.isConstructor).exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
2021+
info.decls.filter(_.isTerm)
2022+
.filter(_ != sym.maybeOwner)
2023+
.exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
20222024
case _ => false
20232025
checkOwnerMemberSignatures(owner)
20242026

20252027
def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean =
2028+
// TODO(kπ) Add tracked if the param is passed as a tracked arg in parent. Can we touch the inheritance terms?
20262029
val owner = sym.maybeOwner.maybeOwner
20272030
val accessorSyms = maybeParamAccessors(owner, sym)
20282031
owner.infoOrCompleter match
@@ -2035,10 +2038,18 @@ class Namer { typer: Typer =>
20352038
case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms)
20362039
case _ => false
20372040

2038-
private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean =
2039-
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
2040-
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
2041-
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)
2041+
private def tpeContainsSymbolRef(tpe0: Type, syms: List[Symbol])(using Context): Boolean =
2042+
val tpe = tpe0.dropAlias.widenExpr.dealias
2043+
tpe match
2044+
case m : MethodOrPoly =>
2045+
m.paramInfos.exists(tpeContainsSymbolRef(_, syms))
2046+
|| tpeContainsSymbolRef(m.resultType, syms)
2047+
case r @ RefinedType(parent, _, refinedInfo) => tpeContainsSymbolRef(parent, syms) || tpeContainsSymbolRef(refinedInfo, syms)
2048+
case TypeBounds(lo, hi) => tpeContainsSymbolRef(lo, syms) || tpeContainsSymbolRef(hi, syms)
2049+
case t: Type =>
2050+
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
2051+
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
2052+
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)
20422053

20432054
private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] =
20442055
owner.infoOrCompleter match

tests/pos/infer-tracked-1.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
trait Ordering {
5+
type T
6+
def compare(t1:T, t2: T): Int
7+
}
8+
9+
class SetFunctor(val ord: Ordering) {
10+
type Set = List[ord.T]
11+
def empty: Set = Nil
12+
13+
implicit class helper(s: Set) {
14+
def add(x: ord.T): Set = x :: remove(x)
15+
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
16+
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
17+
}
18+
}
19+
20+
object Test {
21+
val orderInt = new Ordering {
22+
type T = Int
23+
def compare(t1: T, t2: T): Int = t1 - t2
24+
}
25+
26+
val IntSet = new SetFunctor(orderInt)
27+
import IntSet.*
28+
29+
def main(args: Array[String]) = {
30+
val set = IntSet.empty.add(6).add(8).add(23)
31+
assert(!set.member(7))
32+
assert(set.member(8))
33+
}
34+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
import collection.mutable
5+
6+
/// A parser combinator.
7+
trait Combinator[T]:
8+
9+
/// The context from which elements are being parsed, typically a stream of tokens.
10+
type Context
11+
/// The element being parsed.
12+
type Element
13+
14+
extension (self: T)
15+
/// Parses and returns an element from `context`.
16+
def parse(context: Context): Option[Element]
17+
end Combinator
18+
19+
final case class Apply[C, E](action: C => Option[E])
20+
final case class Combine[A, B](first: A, second: B)
21+
22+
object test:
23+
24+
class apply[C, E] extends Combinator[Apply[C, E]]:
25+
type Context = C
26+
type Element = E
27+
extension(self: Apply[C, E])
28+
def parse(context: C): Option[E] = self.action(context)
29+
30+
def apply[C, E]: apply[C, E] = new apply[C, E]
31+
32+
class combine[A, B](
33+
val f: Combinator[A],
34+
val s: Combinator[B] { type Context = f.Context}
35+
) extends Combinator[Combine[A, B]]:
36+
type Context = f.Context
37+
type Element = (f.Element, s.Element)
38+
extension(self: Combine[A, B])
39+
def parse(context: Context): Option[Element] = ???
40+
41+
def combine[A, B](
42+
_f: Combinator[A],
43+
_s: Combinator[B] { type Context = _f.Context}
44+
) = new combine[A, B](_f, _s)
45+
// cast is needed since the type of new combine[A, B](_f, _s)
46+
// drops the required refinement.
47+
48+
extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
49+
if buf.isEmpty then None
50+
else try Some(buf.head) finally buf.remove(0)
51+
52+
@main def hello: Unit = {
53+
val source = (0 to 10).toList
54+
val stream = source.to(mutable.ListBuffer)
55+
56+
val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
57+
val m = Combine(n, n)
58+
59+
val c = combine(
60+
apply[mutable.ListBuffer[Int], Int],
61+
apply[mutable.ListBuffer[Int], Int]
62+
)
63+
val r = c.parse(m)(stream) // was type mismatch, now OK
64+
val rc: Option[(Int, Int)] = r
65+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
import collection.mutable
5+
6+
/// A parser combinator.
7+
trait Combinator[T]:
8+
9+
/// The context from which elements are being parsed, typically a stream of tokens.
10+
type Context
11+
/// The element being parsed.
12+
type Element
13+
14+
extension (self: T)
15+
/// Parses and returns an element from `context`.
16+
def parse(context: Context): Option[Element]
17+
end Combinator
18+
19+
final case class Apply[C, E](action: C => Option[E])
20+
final case class Combine[A, B](first: A, second: B)
21+
22+
given apply[C, E]: Combinator[Apply[C, E]] with {
23+
type Context = C
24+
type Element = E
25+
extension(self: Apply[C, E]) {
26+
def parse(context: C): Option[E] = self.action(context)
27+
}
28+
}
29+
30+
given combine[A, B](using
31+
val f: Combinator[A],
32+
val s: Combinator[B] { type Context = f.Context }
33+
): Combinator[Combine[A, B]] with {
34+
type Context = f.Context
35+
type Element = (f.Element, s.Element)
36+
extension(self: Combine[A, B]) {
37+
def parse(context: Context): Option[Element] = ???
38+
}
39+
}
40+
41+
extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
42+
if buf.isEmpty then None
43+
else try Some(buf.head) finally buf.remove(0)
44+
45+
@main def hello: Unit = {
46+
val source = (0 to 10).toList
47+
val stream = source.to(mutable.ListBuffer)
48+
49+
val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
50+
val m = Combine(n, n)
51+
52+
val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
53+
val rc: Option[(Int, Int)] = r
54+
// it would be great if this worked
55+
}

tests/pos/infer-tracked-vector.scala

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
object typeparams:
5+
sealed trait Nat
6+
object Z extends Nat
7+
final case class S[N <: Nat]() extends Nat
8+
9+
type Zero = Z.type
10+
type Succ[N <: Nat] = S[N]
11+
12+
sealed trait Fin[N <: Nat]
13+
case class FZero[N <: Nat]() extends Fin[Succ[N]]
14+
case class FSucc[N <: Nat](pred: Fin[N]) extends Fin[Succ[N]]
15+
16+
object Fin:
17+
def zero[N <: Nat]: Fin[Succ[N]] = FZero()
18+
def succ[N <: Nat](i: Fin[N]): Fin[Succ[N]] = FSucc(i)
19+
20+
sealed trait Vec[A, N <: Nat]
21+
case class VNil[A]() extends Vec[A, Zero]
22+
case class VCons[A, N <: Nat](head: A, tail: Vec[A, N]) extends Vec[A, Succ[N]]
23+
24+
object Vec:
25+
def empty[A]: Vec[A, Zero] = VNil()
26+
def cons[A, N <: Nat](head: A, tail: Vec[A, N]): Vec[A, Succ[N]] = VCons(head, tail)
27+
28+
def get[A, N <: Nat](v: Vec[A, N], index: Fin[N]): A = (v, index) match
29+
case (VCons(h, _), FZero()) => h
30+
case (VCons(_, t), FSucc(pred)) => get(t, pred)
31+
32+
def runVec(): Unit =
33+
val v: Vec[Int, Succ[Succ[Succ[Zero]]]] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))
34+
35+
println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
36+
println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
37+
println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
38+
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") // error
39+
40+
// TODO(kπ) check if I can get it to work
41+
// object typemembers:
42+
// sealed trait Nat
43+
// object Z extends Nat
44+
// case class S() extends Nat:
45+
// type N <: Nat
46+
47+
// type Zero = Z.type
48+
// type Succ[N1 <: Nat] = S { type N = N1 }
49+
50+
// sealed trait Fin:
51+
// type N <: Nat
52+
// case class FZero[N1 <: Nat]() extends Fin:
53+
// type N = Succ[N1]
54+
// case class FSucc(tracked val pred: Fin) extends Fin:
55+
// type N = Succ[pred.N]
56+
57+
// object Fin:
58+
// def zero[N1 <: Nat]: Fin { type N = Succ[N1] } = FZero[N1]()
59+
// def succ[N1 <: Nat](i: Fin { type N = N1 }): Fin { type N = Succ[N1] } = FSucc(i)
60+
61+
// sealed trait Vec[A]:
62+
// type N <: Nat
63+
// case class VNil[A]() extends Vec[A]:
64+
// type N = Zero
65+
// case class VCons[A](head: A, tracked val tail: Vec[A]) extends Vec[A]:
66+
// type N = Succ[tail.N]
67+
68+
// object Vec:
69+
// def empty[A]: Vec[A] = VNil()
70+
// def cons[A](head: A, tail: Vec[A]): Vec[A] = VCons(head, tail)
71+
72+
// def get[A](v: Vec[A], index: Fin { type N = v.N }): A = (v, index) match
73+
// case (VCons(h, _), FZero()) => h
74+
// case (VCons(_, t), FSucc(pred)) => get(t, pred)
75+
76+
// // def runVec(): Unit =
77+
// val v: Vec[Int] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))
78+
79+
// println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
80+
// println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
81+
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
82+
// // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}")

tests/pos/infer-tracked.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,37 @@ class F(val x: C):
1010

1111
class G(override val x: C) extends F(x)
1212

13+
class H(val x: C):
14+
type T1 = x.T
15+
val result: T1 = x.foo
16+
17+
class I(val c: C, val t: c.T)
18+
19+
case class J(c: C):
20+
val result: c.T = c.foo
21+
22+
case class K(c: C):
23+
def result[B >: c.T]: B = c.foo
24+
1325
def Test =
1426
val c = new C:
1527
type T = Int
1628
def foo = 42
1729

1830
val f = new F(c)
19-
val i: Int = f.result
31+
val _: Int = f.result
2032

2133
// val g = new G(c)
22-
// val j: Int = g.result
34+
// val _: Int = g.result
35+
36+
val h = new H(c)
37+
val _: Int = h.result
38+
39+
val i = new I(c, c.foo)
40+
val _: Int = i.t
41+
42+
val j = J(c)
43+
val _: Int = j.result
44+
45+
val k = K(c)
46+
val _: Int = k.result

0 commit comments

Comments
 (0)