Skip to content

Commit 8ed6bde

Browse files
committed
Better level-checking in constraints to handle bad bounds
Previously, we enforced level-correctness only when instantiating type variables, but this is not good enough as demonstrated by tests/neg/i8900.scala. The problem is that if we allow level-incorrect bounds, then we might end up reasoning with bad bounds outside of the scope where they are defined. This can lead to level-correct but unsound instantiations. To prevent this, we now enforce level-correctness in constraints at all times, we also introduce `AvoidMap` to share more logic between level-avoidance and symbol-avoidance (see also the added TODO on `TypeOps#avoid`). Note that this implementation is still incomplete: we only check the nestingLevel of NamedTypes, but we also need to check for TypeVars, this will be handled in the next commit.
1 parent 07588ab commit 8ed6bde

File tree

7 files changed

+172
-114
lines changed

7 files changed

+172
-114
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ trait ConstraintHandling {
8181
assert(homogenizeArgs == false)
8282
assert(comparedTypeLambdas == Set.empty)
8383

84+
def nestingLevel(param: TypeParamRef) = constraint.typeVarOfParam(param) match
85+
case tv: TypeVar => tv.nestingLevel
86+
case _ => Int.MaxValue
87+
8488
def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)
8589

8690
def fullLowerBound(param: TypeParamRef)(using Context): Type =
@@ -97,23 +101,57 @@ trait ConstraintHandling {
97101
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
98102
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
99103

104+
/** An approximating map that prevents types nested deeper than maxLevel as
105+
* well as WildcardTypes from leaking into the constraint.
106+
* Note that level-checking is turned off after typer and in uncommitable
107+
* TyperState since these leaks should be safe.
108+
*/
109+
class LevelAvoidMap(topLevelVariance: Int, maxLevel: Int)(using Context) extends TypeOps.AvoidMap:
110+
variance = topLevelVariance
111+
112+
/** Are we allowed to refer to types of the given `level`? */
113+
private def levelOK(level: Int): Boolean =
114+
level <= maxLevel || ctx.isAfterTyper || !ctx.typerState.isCommittable
115+
116+
def toAvoid(tp: NamedType): Boolean =
117+
tp.prefix == NoPrefix && !tp.symbol.isStatic && !levelOK(tp.symbol.nestingLevel)
118+
119+
override def mapWild(t: WildcardType) =
120+
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
121+
else
122+
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds)
123+
tvar
124+
end LevelAvoidMap
125+
126+
/** Approximate `rawBound` if needed to make it a legal bound of `param` by
127+
* avoiding wildcards and types with a level strictly greater than its
128+
* `nestingLevel`.
129+
*
130+
* Note that level-checking must be performed here and cannot be delayed
131+
* until instantiation because if we allow level-incorrect bounds, then we
132+
* might end up reasoning with bad bounds outside of the scope where they are
133+
* defined. This can lead to level-correct but unsound instantiations as
134+
* demonstrated by tests/neg/i8900.scala.
135+
*/
136+
protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type =
137+
// Over-approximate for soundness.
138+
var variance = if isUpper then -1 else 1
139+
// ...unless we can only infer necessary constraints, in which case we
140+
// flip the variance to under-approximate.
141+
if necessaryConstraintsOnly then variance = -variance
142+
143+
val approx = LevelAvoidMap(variance, nestingLevel(param))
144+
approx(rawBound)
145+
end legalBound
146+
100147
protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
101148
if !constraint.contains(param) then true
102149
else if !isUpper && param.occursIn(rawBound) then
103150
// We don't allow recursive lower bounds when defining a type,
104151
// so we shouldn't allow them as constraints either.
105152
false
106153
else
107-
val dropWildcards = new AvoidWildcardsMap:
108-
// Approximate the upper-bound from below and vice-versa
109-
if isUpper then variance = -1
110-
// ...unless we can only infer necessary constraints, in which case we
111-
// flip the variance to under-approximate.
112-
if necessaryConstraintsOnly then variance = -variance
113-
override def mapWild(t: WildcardType) =
114-
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
115-
else newTypeVar(apply(t.effectiveBounds).toBounds)
116-
val bound = dropWildcards(rawBound)
154+
val bound = legalBound(param, rawBound, isUpper)
117155
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
118156
val equalBounds = (if isUpper then lo else hi) eq bound
119157
if equalBounds && !bound.existsPart(_ eq param, StopAt.Static) then

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ final class ProperGadtConstraint private(
7979
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
8080
}
8181

82+
override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type =
83+
// GADT constraints never involve wildcards and are not propagated outside
84+
// the case where they're valid, so no approximating is needed.
85+
rawBound
86+
8287
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
8388
import NameKinds.DepParamName
8489

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 84 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -410,91 +410,109 @@ object TypeOps:
410410
}
411411
}
412412

413-
/** An upper approximation of the given type `tp` that does not refer to any symbol in `symsToAvoid`.
413+
/** An approximating map that drops NamedTypes matching `toAvoid` and wildcard types. */
414+
abstract class AvoidMap(using Context) extends AvoidWildcardsMap:
415+
@threadUnsafe lazy val localParamRefs = util.HashSet[Type]()
416+
417+
def toAvoid(tp: NamedType): Boolean
418+
419+
/** True iff all NamedTypes on this prefix are static */
420+
override def isStaticPrefix(pre: Type)(using Context): Boolean = pre match
421+
case pre: NamedType =>
422+
val sym = pre.currentSymbol
423+
sym.is(Package) || sym.isStatic && isStaticPrefix(pre.prefix)
424+
case _ => true
425+
426+
override def apply(tp: Type): Type = tp match
427+
case tp: TermRef
428+
if toAvoid(tp) =>
429+
tp.info.widenExpr.dealias match {
430+
case info: SingletonType => apply(info)
431+
case info => range(defn.NothingType, apply(info))
432+
}
433+
case tp: TypeRef if toAvoid(tp) =>
434+
tp.info match {
435+
case info: AliasingBounds =>
436+
apply(info.alias)
437+
case TypeBounds(lo, hi) =>
438+
range(atVariance(-variance)(apply(lo)), apply(hi))
439+
case info: ClassInfo =>
440+
range(defn.NothingType, apply(classBound(info)))
441+
case _ =>
442+
emptyRange // should happen only in error cases
443+
}
444+
case tp: ThisType =>
445+
// ThisType is only used inside a class.
446+
// Therefore, either they don't appear in the type to be avoided, or
447+
// it must be a class that encloses the block whose type is to be avoided.
448+
tp
449+
case tp: LazyRef =>
450+
if localParamRefs.contains(tp.ref) then tp
451+
else if isExpandingBounds then emptyRange
452+
else mapOver(tp)
453+
case tl: HKTypeLambda =>
454+
localParamRefs ++= tl.paramRefs
455+
mapOver(tl)
456+
case _ =>
457+
super.apply(tp)
458+
end apply
459+
460+
/** Three deviations from standard derivedSelect:
461+
* 1. We first try a widening conversion to the type's info with
462+
* the original prefix. Since the original prefix is known to
463+
* be a subtype of the returned prefix, this can improve results.
464+
* 2. Then, if the approximation result is a singleton reference C#x.type, we
465+
* replace by the widened type, which is usually more natural.
466+
* 3. Finally, we need to handle the case where the prefix type does not have a member
467+
* named `tp.name` anymmore. In that case, we need to fall back to Bot..Top.
468+
*/
469+
override def derivedSelect(tp: NamedType, pre: Type) =
470+
if (pre eq tp.prefix)
471+
tp
472+
else tryWiden(tp, tp.prefix).orElse {
473+
if (tp.isTerm && variance > 0 && !pre.isSingleton)
474+
apply(tp.info.widenExpr)
475+
else if (upper(pre).member(tp.name).exists)
476+
super.derivedSelect(tp, pre)
477+
else
478+
range(defn.NothingType, defn.AnyType)
479+
}
480+
end AvoidMap
481+
482+
/** An upper approximation of the given type `tp` that does not refer to any symbol in `symsToAvoid`
483+
* and does not contain any WildcardType.
414484
* We need to approximate with ranges:
415485
*
416486
* term references to symbols in `symsToAvoid`,
417487
* term references that have a widened type of which some part refers
418488
* to a symbol in `symsToAvoid`,
419489
* type references to symbols in `symsToAvoid`,
420-
* this types of classes in `symsToAvoid`.
421490
*
422491
* Type variables that would be interpolated to a type that
423492
* needs to be widened are replaced by the widened interpolation instance.
493+
*
494+
* TODO: Could we replace some or all usages of this method by
495+
* `LevelAvoidMap` instead? It would be good to investigate this in details
496+
* but when I tried it, avoidance for inlined trees broke because `TreeMap`
497+
* does not update `ctx.nestingLevel` when entering a block so I'm leaving
498+
* this as Future Work™.
424499
*/
425500
def avoid(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type = {
426-
val widenMap = new ApproximatingTypeMap {
501+
val widenMap = new AvoidMap {
427502
@threadUnsafe lazy val forbidden = symsToAvoid.toSet
428-
@threadUnsafe lazy val localParamRefs = util.HashSet[Type]()
429-
def toAvoid(sym: Symbol) = !sym.isStatic && forbidden.contains(sym)
430-
def partsToAvoid = new NamedPartsAccumulator(tp => toAvoid(tp.symbol))
431-
432-
/** True iff all NamedTypes on this prefix are static */
433-
override def isStaticPrefix(pre: Type)(using Context): Boolean = pre match
434-
case pre: NamedType =>
435-
val sym = pre.currentSymbol
436-
sym.is(Package) || sym.isStatic && isStaticPrefix(pre.prefix)
437-
case _ => true
438-
439-
def apply(tp: Type): Type = tp match
440-
case tp: TermRef
441-
if toAvoid(tp.symbol) || partsToAvoid(Nil, tp.info).nonEmpty =>
442-
tp.info.widenExpr.dealias match {
443-
case info: SingletonType => apply(info)
444-
case info => range(defn.NothingType, apply(info))
445-
}
446-
case tp: TypeRef if toAvoid(tp.symbol) =>
447-
tp.info match {
448-
case info: AliasingBounds =>
449-
apply(info.alias)
450-
case TypeBounds(lo, hi) =>
451-
range(atVariance(-variance)(apply(lo)), apply(hi))
452-
case info: ClassInfo =>
453-
range(defn.NothingType, apply(classBound(info)))
454-
case _ =>
455-
emptyRange // should happen only in error cases
456-
}
457-
case tp: ThisType =>
458-
// ThisType is only used inside a class.
459-
// Therefore, either they don't appear in the type to be avoided, or
460-
// it must be a class that encloses the block whose type is to be avoided.
461-
tp
503+
def toAvoid(tp: NamedType) =
504+
val sym = tp.symbol
505+
!sym.isStatic && forbidden.contains(sym)
506+
507+
override def apply(tp: Type): Type = tp match
462508
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
463509
val lo = TypeComparer.instanceType(
464510
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
465511
val lo1 = apply(lo)
466512
if (lo1 ne lo) lo1 else tp
467-
case tp: LazyRef =>
468-
if localParamRefs.contains(tp.ref) then tp
469-
else if isExpandingBounds then emptyRange
470-
else mapOver(tp)
471-
case tl: HKTypeLambda =>
472-
localParamRefs ++= tl.paramRefs
473-
mapOver(tl)
474513
case _ =>
475-
mapOver(tp)
514+
super.apply(tp)
476515
end apply
477-
478-
/** Three deviations from standard derivedSelect:
479-
* 1. We first try a widening conversion to the type's info with
480-
* the original prefix. Since the original prefix is known to
481-
* be a subtype of the returned prefix, this can improve results.
482-
* 2. Then, if the approximation result is a singleton reference C#x.type, we
483-
* replace by the widened type, which is usually more natural.
484-
* 3. Finally, we need to handle the case where the prefix type does not have a member
485-
* named `tp.name` anymmore. In that case, we need to fall back to Bot..Top.
486-
*/
487-
override def derivedSelect(tp: NamedType, pre: Type) =
488-
if (pre eq tp.prefix)
489-
tp
490-
else tryWiden(tp, tp.prefix).orElse {
491-
if (tp.isTerm && variance > 0 && !pre.isSingleton)
492-
apply(tp.info.widenExpr)
493-
else if (upper(pre).member(tp.name).exists)
494-
super.derivedSelect(tp, pre)
495-
else
496-
range(defn.NothingType, defn.AnyType)
497-
}
498516
}
499517

500518
widenMap(tp)

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4668,7 +4668,6 @@ object Types {
46684668
* @param creatorState The typer state in which the variable was created.
46694669
*/
46704670
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, val nestingLevel: Int) extends CachedProxyType with ValueType {
4671-
46724671
private var currentOrigin = initOrigin
46734672

46744673
def origin: TypeParamRef = currentOrigin
@@ -4709,38 +4708,6 @@ object Types {
47094708
/** Is the variable already instantiated? */
47104709
def isInstantiated(using Context): Boolean = instanceOpt.exists
47114710

4712-
/** Avoid term references in `tp` to parameters or local variables that
4713-
* are nested more deeply than the type variable itself.
4714-
*/
4715-
private def avoidCaptures(tp: Type)(using Context): Type =
4716-
if ctx.isAfterTyper then
4717-
return tp
4718-
val problemSyms = new TypeAccumulator[Set[Symbol]]:
4719-
def apply(syms: Set[Symbol], t: Type): Set[Symbol] = t match
4720-
case ref: NamedType
4721-
// AVOIDANCE TODO: Are there other problematic kinds of references?
4722-
// Our current tests only give us these, but we might need to generalize this.
4723-
if (ref.prefix eq NoPrefix) && ref.symbol.maybeOwner.nestingLevel > nestingLevel =>
4724-
syms + ref.symbol
4725-
case _ =>
4726-
foldOver(syms, t)
4727-
val problems = problemSyms(Set.empty, tp)
4728-
if problems.isEmpty then tp
4729-
else
4730-
val atp = TypeOps.avoid(tp, problems.toList)
4731-
def msg = i"Inaccessible variables captured in instantation of type variable $this.\n$tp was fixed to $atp"
4732-
typr.println(msg)
4733-
val bound = TypeComparer.fullUpperBound(origin)
4734-
if !(atp <:< bound) then
4735-
throw new TypeError(i"$msg,\nbut the latter type does not conform to the upper bound $bound")
4736-
atp
4737-
// AVOIDANCE TODO: This really works well only if variables are instantiated from below
4738-
// If we hit a problematic symbol while instantiating from above, then avoidance
4739-
// will widen the instance type further. This could yield an alias, which would be OK.
4740-
// But it also could yield a true super type which would then fail the bounds check
4741-
// and throw a TypeError. The right thing to do instead would be to avoid "downwards".
4742-
// To do this, we need first test cases for that situation.
4743-
47444711
/** Instantiate variable with given type */
47454712
def instantiateWith(tp: Type)(using Context): Type = {
47464713
assert(tp ne this, i"self instantiation of $origin, constraint = ${ctx.typerState.constraint}")
@@ -4765,7 +4732,7 @@ object Types {
47654732
* is also a singleton type.
47664733
*/
47674734
def instantiate(fromBelow: Boolean)(using Context): Type =
4768-
val tp = avoidCaptures(TypeComparer.instanceType(origin, fromBelow))
4735+
val tp = TypeComparer.instanceType(origin, fromBelow)
47694736
if myInst.exists then // The line above might have triggered instantiation of the current type variable
47704737
myInst
47714738
else

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ object Inferencing {
447447
* We need to take the occurences in `pt` into account because a type
448448
* variable created when typing the current tree might only appear in the
449449
* bounds of a type variable in the expected type, for example when
450-
* `ConstraintHandling#addOneBound` creates type variables when approximating
450+
* `ConstraintHandling#legalBound` creates type variables when approximating
451451
* a bound.
452452
*
453453
* Note: We intentionally use a relaxed version of variance here,

tests/neg/i8900.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
trait Base {
2+
type M
3+
}
4+
trait A {
5+
type M >: Int | String
6+
}
7+
trait B {
8+
type M <: Int & String
9+
}
10+
object Test {
11+
def foo[T](z: T, x: A & B => T): T = z
12+
def foo2[T](z: T, x: T): T = z
13+
14+
def main(args: Array[String]): Unit = {
15+
val x = foo(1, x => (??? : x.M))
16+
val x1: String = x // error (was: ClassCastException)
17+
18+
val a = foo2(1,
19+
if false then
20+
val x: A & B = ???
21+
??? : x.M
22+
else 1
23+
)
24+
25+
val b: String = a // error (was: ClassCastException)
26+
}
27+
}
28+

tests/neg/i8861.scala renamed to tests/run/i8861.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ object Test {
1818
int = vi => vi.i : vi.A,
1919
str = vs => vs.t : vs.A
2020
)
21+
// Used to infer `c.visit[Int & M)]` and error out in the second lambda,
22+
// now infers `c.visit[(Int & M | String & M)]`
2123
def minimalFail[M](c: Container { type A = M }): M = c.visit(
2224
int = vi => vi.i : vi.A,
23-
str = vs => vs.t : vs.A // error // error
25+
str = vs => vs.t : vs.A
2426
)
2527

2628
def main(args: Array[String]): Unit = {
2729
val e: Container { type A = String } = new StrV
2830
println(minimalOk(e)) // this one prints "hello"
29-
println(minimalFail(e)) // this one fails with ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
31+
println(minimalFail(e)) // used to fail with ClassCastException, now prints "hello"
3032
}
31-
}
33+
}

0 commit comments

Comments
 (0)