diff --git a/compiler/src/dotty/tools/dotc/core/NamerOps.scala b/compiler/src/dotty/tools/dotc/core/NamerOps.scala index 75a135826785..8d096913e285 100644 --- a/compiler/src/dotty/tools/dotc/core/NamerOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NamerOps.scala @@ -5,6 +5,7 @@ package core import Contexts.*, Symbols.*, Types.*, Flags.*, Scopes.*, Decorators.*, Names.*, NameOps.* import SymDenotations.{LazyType, SymDenotation}, StdNames.nme import TypeApplications.EtaExpansion +import collection.mutable /** Operations that are shared between Namer and TreeUnpickler */ object NamerOps: @@ -18,6 +19,26 @@ object NamerOps: case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef)) case _ => ctor.owner.typeRef + /** Split dependent class refinements off parent type. Add them to `refinements`, + * unless it is null. + */ + extension (tp: Type) + def separateRefinements(cls: ClassSymbol, refinements: mutable.LinkedHashMap[Name, Type] | Null)(using Context): Type = + tp match + case RefinedType(tp1, rname, rinfo) => + try tp1.separateRefinements(cls, refinements) + finally + if refinements != null then + refinements(rname) = refinements.get(rname) match + case Some(tp) => tp & rinfo + case None => rinfo + case tp @ AnnotatedType(tp1, ann) => + tp.derivedAnnotatedType(tp1.separateRefinements(cls, refinements), ann) + case tp: RecType => + tp.parent.substRecThis(tp, cls.thisType).separateRefinements(cls, refinements) + case tp => + tp + /** If isConstructor, make sure it has at least one non-implicit parameter list * This is done by adding a () in front of a leading old style implicit parameter, * or by adding a () as last -- or only -- parameter list if the constructor has diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index b7a25cb75613..71995bb2f18b 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -1049,7 +1049,7 @@ class TreeUnpickler(reader: TastyReader, } val parentReader = fork val parents = readParents(withArgs = false)(using parentCtx) - val parentTypes = parents.map(_.tpe.dealias) + val parentTypes = parents.map(_.tpe.dealiasKeepAnnots.separateRefinements(cls, null)) if cls.is(JavaDefined) && parentTypes.exists(_.derivesFrom(defn.JavaAnnotationClass)) then cls.setFlag(JavaAnnotation) val self = diff --git a/compiler/src/dotty/tools/dotc/transform/init/Util.scala b/compiler/src/dotty/tools/dotc/transform/init/Util.scala index 70390028e84f..14bcf6fa61bf 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Util.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Util.scala @@ -20,6 +20,7 @@ object Util: def typeRefOf(tp: Type)(using Context): TypeRef = tp.dealias.typeConstructor match case tref: TypeRef => tref + case RefinedType(parent, _, _) => typeRefOf(parent) case hklambda: HKTypeLambda => typeRefOf(hklambda.resType) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index f8ced1c6599a..36ffdd7e5693 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -54,11 +54,12 @@ class Namer { typer: Typer => import untpd.* - val TypedAhead : Property.Key[tpd.Tree] = new Property.Key - val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key - val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key - val SymOfTree : Property.Key[Symbol] = new Property.Key - val AttachedDeriver : Property.Key[Deriver] = new Property.Key + val TypedAhead : Property.Key[tpd.Tree] = new Property.Key + val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key + val ExportForwarders : Property.Key[List[tpd.MemberDef]] = new Property.Key + val ParentRefinements: Property.Key[List[Symbol]] = new Property.Key + val SymOfTree : Property.Key[Symbol] = new Property.Key + val AttachedDeriver : Property.Key[Deriver] = new Property.Key // was `val Deriver`, but that gave shadowing problems with constructor proxies /** A partial map from unexpanded member and pattern defs and to their expansions. @@ -1485,6 +1486,7 @@ class Namer { typer: Typer => /** The type signature of a ClassDef with given symbol */ override def completeInCreationContext(denot: SymDenotation): Unit = { val parents = impl.parents + val parentRefinements = new mutable.LinkedHashMap[Name, Type] /* The type of a parent constructor. Types constructor arguments * only if parent type contains uninstantiated type parameters. @@ -1536,7 +1538,8 @@ class Namer { typer: Typer => val ptype = parentType(parent)(using completerCtx.superCallContext).dealiasKeepAnnots if (cls.isRefinementClass) ptype else { - val pt = checkClassType(ptype, parent.srcPos, + val pt = checkClassType( + ptype.separateRefinements(cls, parentRefinements), parent.srcPos, traitReq = parent ne parents.head, stablePrefixReq = true) if (pt.derivesFrom(cls)) { val addendum = parent match { @@ -1564,6 +1567,21 @@ class Namer { typer: Typer => } } + /** Enter all parent refinements as public class members, unless a definition + * with the same name already exists in the class. + */ + def enterParentRefinementSyms(refinements: List[(Name, Type)]) = + val refinedSyms = mutable.ListBuffer[Symbol]() + for (name, tp) <- refinements do + if decls.lookupEntry(name) == null then + val flags = tp match + case tp: MethodOrPoly => Method | Synthetic | Deferred + case _ => Synthetic | Deferred + refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered + if refinedSyms.nonEmpty then + typr.println(i"parent refinement symbols: ${refinedSyms.toList}") + original.pushAttachment(ParentRefinements, refinedSyms.toList) + /** If `parents` contains references to traits that have supertraits with implicit parameters * add those supertraits in linearization order unless they are already covered by other * parent types. For instance, in @@ -1632,6 +1650,7 @@ class Namer { typer: Typer => cls.invalidateMemberCaches() // we might have checked for a member when parents were not known yet. cls.setNoInitsFlags(parentsKind(parents), untpd.bodyKind(rest)) cls.setStableConstructor() + enterParentRefinementSyms(parentRefinements.toList) processExports(using localCtx) defn.patchStdLibClass(cls) addConstructorProxies(cls) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 2f03c79754e8..bb15f1b5e8e1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -912,7 +912,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if (templ1.parents.isEmpty && isFullyDefined(pt, ForceDegree.flipBottom) && isSkolemFree(pt) && - isEligible(pt.underlyingClassRef(refinementOK = false))) + isEligible(pt.underlyingClassRef(refinementOK = true))) templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil) for case parent: RefTree <- templ1.parents do typedAhead(parent, tree => inferTypeParams(typedType(tree), pt)) @@ -2766,6 +2766,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } } + /** Add all parent refinement symbols as declarations to this class */ + def addParentRefinements(body: List[Tree])(using Context): List[Tree] = + cdef.getAttachment(ParentRefinements) match + case Some(refinedSyms) => + val refinements = refinedSyms.map: sym => + ( if sym.isType then TypeDef(sym.asType) + else if sym.is(Method) then DefDef(sym.asTerm) + else ValDef(sym.asTerm) + ).withSpan(impl.span.startPos) + body ++ refinements + case None => + body + ensureCorrectSuperClass() completeAnnotations(cdef, cls) val constr1 = typed(constr).asInstanceOf[DefDef] @@ -2786,7 +2799,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cdef.withType(UnspecifiedErrorType) else { val dummy = localDummy(cls, impl) - val body1 = addAccessorDefs(cls, typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1) + val body1 = + addParentRefinements( + addAccessorDefs(cls, + typedStats(impl.body, dummy)(using ctx.inClassContext(self1.symbol))._1)) checkNoDoubleDeclaration(cls) val impl1 = cpy.Template(impl)(constr1, parents1, Nil, self1, body1) diff --git a/tests/neg/i0248-inherit-refined.scala b/tests/neg/i0248-inherit-refined.scala index 97b6f5cdab73..afe009e9b28e 100644 --- a/tests/neg/i0248-inherit-refined.scala +++ b/tests/neg/i0248-inherit-refined.scala @@ -1,10 +1,10 @@ object test { class A { type T } type X = A { type T = Int } - class B extends X // error + class B extends X // was error, now OK type Y = A & B class C extends Y // error type Z = A | B class D extends Z // error - abstract class E extends ({ val x: Int }) // error + abstract class E extends ({ val x: Int }) // was error, now OK } diff --git a/tests/neg/parent-refinement-access.check b/tests/neg/parent-refinement-access.check new file mode 100644 index 000000000000..992be56bc43f --- /dev/null +++ b/tests/neg/parent-refinement-access.check @@ -0,0 +1,7 @@ +-- [E164] Declaration Error: tests/neg/parent-refinement-access.scala:4:6 ---------------------------------------------- +4 |trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error + | ^ + | error overriding value x in trait Year2 of type Int; + | value x in trait Gen of type Any has weaker access privileges; it should be public + | (Note that value x in trait Year2 of type Int is abstract, + | and is therefore overridden by concrete value x in trait Gen of type Any) diff --git a/tests/neg/parent-refinement-access.scala b/tests/neg/parent-refinement-access.scala new file mode 100644 index 000000000000..51125321baa7 --- /dev/null +++ b/tests/neg/parent-refinement-access.scala @@ -0,0 +1,4 @@ +trait Gen: + private[Gen] val x: Any = () + +trait Year2(private[Year2] val value: Int) extends (Gen { val x: Int }) // error diff --git a/tests/neg/parent-refinement.check b/tests/neg/parent-refinement.check index 550430bd35a7..83afb80ebf35 100644 --- a/tests/neg/parent-refinement.check +++ b/tests/neg/parent-refinement.check @@ -1,4 +1,25 @@ --- Error: tests/neg/parent-refinement.scala:5:2 ------------------------------------------------------------------------ -5 | with Ordered[Year] { // error - | ^^^^ - | end of toplevel definition expected but 'with' found +-- Error: tests/neg/parent-refinement.scala:10:6 ----------------------------------------------------------------------- +10 |class Bar extends IdOf[Int], (X { type Value = String }) // error + | ^^^ + |class Bar cannot be instantiated since it has a member Value with possibly conflicting bounds Int | String <: ... <: Int & String +-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:14:17 ------------------------------------------------- +14 | val x: Value = 0 // error + | ^ + | Found: (0 : Int) + | Required: Baz.this.Value + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:20:6 -------------------------------------------------- +20 | foo(2) // error + | ^ + | Found: (2 : Int) + | Required: Boolean + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/parent-refinement.scala:16:22 ------------------------------------------------- +16 |val x: IdOf[Int] = Baz() // error + | ^^^^^ + | Found: Baz + | Required: IdOf[Int] + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg/parent-refinement.scala b/tests/neg/parent-refinement.scala index ca2b88a75fd8..a6cdbee92b37 100644 --- a/tests/neg/parent-refinement.scala +++ b/tests/neg/parent-refinement.scala @@ -1,7 +1,20 @@ trait Id { type Value } +trait X { type Value } +type IdOf[T] = Id { type Value = T } + case class Year(value: Int) extends AnyVal - with Id { type Value = Int } - with Ordered[Year] { // error + with (Id { type Value = Int }) + with Ordered[Year] + +class Bar extends IdOf[Int], (X { type Value = String }) // error + +class Baz extends IdOf[Int]: + type Value = String + val x: Value = 0 // error + +val x: IdOf[Int] = Baz() // error -} \ No newline at end of file +object Clash extends ({ def foo(x: Int): Int }): + def foo(x: Boolean): Int = 1 + foo(2) // error diff --git a/tests/pos/parent-refinement.scala b/tests/pos/parent-refinement.scala new file mode 100644 index 000000000000..5b435b656133 --- /dev/null +++ b/tests/pos/parent-refinement.scala @@ -0,0 +1,46 @@ +class A +class B extends A +class C extends B + +trait Id { type Value } +type IdOf[T] = Id { type Value = T } +trait X { type Value } + +case class Year(value: Int) extends IdOf[Int]: + val x: Value = 2 + +type Between[Lo, Hi] = X { type Value >: Lo <: Hi } + +class Foo() extends IdOf[B], Between[C, A]: + val x: Value = B() + +trait Bar extends IdOf[Int], (X { type Value = String }) + +class Baz extends IdOf[Int]: + type Value = String + val x: Value = "" + +trait Gen: + type T + val x: T + +type IntInst = Gen: + type T = Int + val x: 0 + +trait IntInstTrait extends IntInst + +abstract class IntInstClass extends IntInstTrait, IntInst + +object obj1 extends IntInstTrait: + val x = 0 + +object obj2 extends IntInstClass: + val x = 0 + +def main = + val x: obj1.T = 2 - obj2.x + val y: obj2.T = 2 - obj1.x + + + diff --git a/tests/pos/typeclasses.scala b/tests/pos/typeclasses.scala index 07fe5a31ce5d..83a424ab92a8 100644 --- a/tests/pos/typeclasses.scala +++ b/tests/pos/typeclasses.scala @@ -1,8 +1,5 @@ class Common: - // this should go in Predef - infix type at [A <: { type This}, B] = A { type This = B } - trait Ord: type This extension (x: This) @@ -26,41 +23,23 @@ class Common: extension [A](x: This[A]) def flatMap[B](f: A => This[B]): This[B] def map[B](f: A => B) = x.flatMap(f `andThen` pure) + + infix type is[A <: AnyKind, B <: {type This <: AnyKind}] = B { type This = A } + end Common object Instances extends Common: -/* - instance Int: Ord as intOrd with - extension (x: Int) - def compareTo(y: Int) = - if x < y then -1 - else if x > y then +1 - else 0 -*/ - given intOrd: Ord with + given intOrd: (Int is Ord) with type This = Int extension (x: Int) def compareTo(y: Int) = if x < y then -1 else if x > y then +1 else 0 -/* - instance List[T: Ord]: Ord as listOrd with - extension (xs: List[T]) def compareTo(ys: List[T]): Int = (xs, ys) match - case (Nil, Nil) => 0 - case (Nil, _) => -1 - case (_, Nil) => +1 - case (x :: xs1, y :: ys1) => - val fst = x.compareTo(y) - if (fst != 0) fst else xs1.compareTo(ys1) -*/ - // Proposed short syntax: - // given listOrd[T: Ord as ord]: Ord at T with - given listOrd[T](using ord: Ord { type This = T}): Ord with - type This = List[T] + given listOrd[T](using ord: T is Ord): (List[T] is Ord) with extension (xs: List[T]) def compareTo(ys: List[T]): Int = (xs, ys) match case (Nil, Nil) => 0 case (Nil, _) => -1 @@ -70,32 +49,18 @@ object Instances extends Common: if (fst != 0) fst else xs1.compareTo(ys1) end listOrd -/* - instance List: Monad as listMonad with + given listMonad: (List is Monad) with extension [A](xs: List[A]) def flatMap[B](f: A => List[B]): List[B] = xs.flatMap(f) def pure[A](x: A): List[A] = List(x) -*/ - given listMonad: Monad with - type This[A] = List[A] - extension [A](xs: List[A]) def flatMap[B](f: A => List[B]): List[B] = - xs.flatMap(f) - def pure[A](x: A): List[A] = - List(x) -/* - type Reader[Ctx] = X =>> Ctx => X - instance Reader[Ctx: _]: Monad as readerMonad with - extension [A](r: Ctx => A) def flatMap[B](f: A => Ctx => B): Ctx => B = - ctx => f(r(ctx))(ctx) - def pure[A](x: A): Ctx => A = - ctx => x -*/ + type Reader[Ctx] = [X] =>> Ctx => X - given readerMonad[Ctx]: Monad with - type This[X] = Ctx => X + //given [Ctx] => Reader[Ctx] is Monad as readerMonad: + + given readerMonad[Ctx]: (Reader[Ctx] is Monad) with extension [A](r: Ctx => A) def flatMap[B](f: A => Ctx => B): Ctx => B = ctx => f(r(ctx))(ctx) def pure[A](x: A): Ctx => A = @@ -110,29 +75,17 @@ object Instances extends Common: def second = xs.tail.head def third = xs.tail.tail.head - //Proposed short syntax: - //extension [M: Monad as m, A](xss: M[M[A]]) - // def flatten: M[A] = - // xs.flatMap(identity) - extension [M, A](using m: Monad)(xss: m.This[m.This[A]]) def flatten: m.This[A] = xss.flatMap(identity) - // Proposed short syntax: - //def maximum[T: Ord](xs: List[T]: T = - def maximum[T](xs: List[T])(using Ord at T): T = + def maximum[T](xs: List[T])(using T is Ord): T = xs.reduceLeft((x, y) => if (x < y) y else x) - // Proposed short syntax: - // def descending[T: Ord as asc]: Ord at T = new Ord: - def descending[T](using asc: Ord at T): Ord at T = new Ord: - type This = T + def descending[T](using asc: T is Ord): T is Ord = new: extension (x: T) def compareTo(y: T) = asc.compareTo(y)(x) - // Proposed short syntax: - // def minimum[T: Ord](xs: List[T]) = - def minimum[T](xs: List[T])(using Ord at T) = + def minimum[T](xs: List[T])(using T is Ord) = maximum(xs)(using descending) def test(): Unit = @@ -177,10 +130,10 @@ instance Sheep: Animal with override def talk(): Unit = println(s"$name pauses briefly... $noise") */ +import Instances.is // Implement the `Animal` trait for `Sheep`. -given Animal with - type This = Sheep +given (Sheep is Animal) with def apply(name: String) = Sheep(name) extension (self: This) def name: String = self.name