diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 763b71619de8..19570f13c519 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -27,6 +27,7 @@ import scala.collection.immutable.ListSet import scala.collection.mutable import scala.annotation.tailrec import scala.annotation.constructorOnly +import dotty.tools.dotc.core.Flags.AbstractOrTrait /** Check initialization safety of static objects * @@ -203,6 +204,7 @@ object Objects: /** * Represents a lambda expression + * @param klass The enclosing class of the anonymous function's creation site */ case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement: def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")" @@ -599,6 +601,26 @@ object Objects: case _ => a + def filterType(tpe: Type)(using Context): Value = + tpe match + case t @ SAMType(_, _) if a.isInstanceOf[Fun] => a // if tpe is SAMType and a is Fun, allow it + case _ => + val baseClasses = tpe.baseClasses + if baseClasses.isEmpty then a + else filterClass(baseClasses.head) // could have called ClassSymbol, but it does not handle OrType and AndType + + def filterClass(sym: Symbol)(using Context): Value = + if !sym.isClass then a + else + val klass = sym.asClass + a match + case Cold => Cold + case ref: Ref => if ref.klass.isSubClass(klass) then ref else Bottom + case ValueSet(values) => values.map(v => v.filterClass(klass)).join + case arr: OfArray => if defn.ArrayClass.isSubClass(klass) then arr else Bottom + case fun: Fun => + if klass.isOneOf(AbstractOrTrait) && klass.baseClasses.exists(defn.isFunctionClass) then fun else Bottom + extension (value: Ref | Cold.type) def widenRefOrCold(height : Int)(using Context) : Ref | Cold.type = value.widen(height).asInstanceOf[ThisValue] @@ -617,7 +639,7 @@ object Objects: * @param needResolve Whether the target of the call needs resolution? */ def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { - value match + value.filterClass(meth.owner) match case Cold => report.warning("Using cold alias. " + Trace.show, Trace.position) Bottom @@ -733,7 +755,6 @@ object Objects: * @param args Arguments of the constructor call (all parameter blocks flatten to a list). */ def callConstructor(value: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { - value match case ref: Ref => if ctor.hasSource then @@ -768,7 +789,7 @@ object Objects: * @param needResolve Whether the target of the selection needs resolution? */ def select(value: Value, field: Symbol, receiver: Type, needResolve: Boolean = true): Contextual[Value] = log("select " + field.show + ", this = " + value.show, printer, (_: Value).show) { - value match + value.filterClass(field.owner) match case Cold => report.warning("Using cold alias", Trace.position) Bottom @@ -839,12 +860,12 @@ object Objects: * @param rhsTyp The type of the right-hand side. */ def assign(lhs: Value, field: Symbol, rhs: Value, rhsTyp: Type): Contextual[Value] = log("Assign" + field.show + " of " + lhs.show + ", rhs = " + rhs.show, printer, (_: Value).show) { - lhs match + lhs.filterClass(field.owner) match case fun: Fun => report.warning("[Internal error] unexpected tree in assignment, fun = " + fun.code.show + Trace.show, Trace.position) case arr: OfArray => - report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + Trace.show, Trace.position) + report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + " field = " + field + Trace.show, Trace.position) case Cold => report.warning("Assigning to cold aliases is forbidden. " + Trace.show, Trace.position) @@ -876,8 +897,7 @@ object Objects: * @param args The arguments passsed to the constructor. */ def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) { - outer match - + outer.filterClass(klass.owner) match case _ : Fun | _: OfArray => report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position) Bottom @@ -1091,6 +1111,9 @@ object Objects: instantiate(outer, cls, ctor, args) } + case TypeCast(elem, tpe) => + eval(elem, thisV, klass).filterType(tpe) + case Apply(ref, arg :: Nil) if ref.symbol == defn.InitRegionMethod => val regions2 = Regions.extend(expr.sourcePos) if Regions.exists(expr.sourcePos) then @@ -1549,7 +1572,7 @@ object Objects: report.warning("The argument should be a constant integer value", arg) res.widen(1) case _ => - res.widen(1) + if res.isInstanceOf[Fun] then res.widen(2) else res.widen(1) argInfos += ArgInfo(widened, trace.add(arg.tree), arg.tree) } diff --git a/compiler/src/dotty/tools/dotc/transform/init/Util.scala b/compiler/src/dotty/tools/dotc/transform/init/Util.scala index 70390028e84f..756fd1a0a8e7 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Util.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Util.scala @@ -78,6 +78,13 @@ object Util: case _ => None + object TypeCast: + def unapply(tree: Tree)(using Context): Option[(Tree, Type)] = + tree match + case TypeApply(Select(qual, _), typeArgs) if tree.symbol.isTypeCast => + Some(qual, typeArgs.head.tpe) + case _ => None + def resolve(cls: ClassSymbol, sym: Symbol)(using Context): Symbol = log("resove " + cls + ", " + sym, printer, (_: Symbol).show): if sym.isEffectivelyFinal then sym else sym.matchingMember(cls.appliedRef) diff --git a/tests/init-global/neg/TypeCast.scala b/tests/init-global/neg/TypeCast.scala new file mode 100644 index 000000000000..55447e9df4e2 --- /dev/null +++ b/tests/init-global/neg/TypeCast.scala @@ -0,0 +1,18 @@ +object A { + val f: Int = 10 + def m() = f +} +object B { + val f: Int = g() + def g(): Int = f // error +} +object C { + val a: A.type | B.type = if ??? then A else B + def cast[T](a: Any): T = a.asInstanceOf[T] + val c: A.type = cast[A.type](a) // abstraction for c is {A, B} + val d = c.f // treat as c.asInstanceOf[owner of f].f + val e = c.m() // treat as c.asInstanceOf[owner of f].m() + val c2: B.type = cast[B.type](a) + val g = c2.f // no error here +} + diff --git a/tests/init-global/pos/TypeCast1.scala b/tests/init-global/pos/TypeCast1.scala new file mode 100644 index 000000000000..e9881c6f5e4d --- /dev/null +++ b/tests/init-global/pos/TypeCast1.scala @@ -0,0 +1,9 @@ +class A: + class B(val b: Int) + +object O: + val o: A | Array[Int] = new Array[Int](10) + o match + case a: A => new a.B(10) + case arr: Array[Int] => arr(5) + diff --git a/tests/init-global/pos/TypeCast2.scala b/tests/init-global/pos/TypeCast2.scala new file mode 100644 index 000000000000..e18c8ffca5d1 --- /dev/null +++ b/tests/init-global/pos/TypeCast2.scala @@ -0,0 +1,9 @@ +class A: + class B(val b: Int) + +object O: + val o: A | (Int => Int) = (x: Int) => x + 1 + o match + case a: A => new a.B(10) + case f: (_ => _) => f.asInstanceOf[Int => Int](5) + diff --git a/tests/init-global/pos/TypeCast3.scala b/tests/init-global/pos/TypeCast3.scala new file mode 100644 index 000000000000..08197790edd6 --- /dev/null +++ b/tests/init-global/pos/TypeCast3.scala @@ -0,0 +1,8 @@ +class A: + var x: Int = 10 + +object O: + val o: A | (Int => Int) = (x: Int) => x + 1 + o match + case a: A => a.x = 20 + case f: (_ => _) => f.asInstanceOf[Int => Int](5) diff --git a/tests/init-global/pos/TypeCast4.scala b/tests/init-global/pos/TypeCast4.scala new file mode 100644 index 000000000000..8b65bc775cc2 --- /dev/null +++ b/tests/init-global/pos/TypeCast4.scala @@ -0,0 +1,9 @@ +class A: + var x: Int = 10 + +object O: + val o: A | Array[Int] = new Array[Int](10) + o match + case a: A => a.x = 20 + case arr: Array[Int] => arr(5) + diff --git a/tests/init-global/pos/i18882.scala b/tests/init-global/pos/i18882.scala new file mode 100644 index 000000000000..0a1ea5309a58 --- /dev/null +++ b/tests/init-global/pos/i18882.scala @@ -0,0 +1,15 @@ +class A: + var a = 20 + +class B: + var b = 20 + +object O: + val o: A | B = new A + if o.isInstanceOf[A] then + o.asInstanceOf[A].a += 1 + else + o.asInstanceOf[B].b += 1 // o.asInstanceOf[B] is treated as bottom + o match + case o: A => o.a += 1 + case o: B => o.b += 1