diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 33b78e9fe945..45b1c56ab4be 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -4,6 +4,7 @@ package dotc import core._ import Contexts._ import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import parsing.Parser import Phases.Phase import transform._ @@ -78,6 +79,8 @@ class Compiler { new RefChecks, // Various checks mostly related to abstract members and overriding new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher) :: // Compile pattern matches + List(new PreRecheck) :: // Preparations for check captures phase, enabled under -Ycc + List(new CheckCaptures) :: // Check captures, enabled under -Ycc List(new ElimOpaque, // Turn opaque into normal aliases new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only) new ExplicitOuter, // Add accessors to outer classes from nested ones. @@ -101,8 +104,6 @@ class Compiler { new TupleOptimizations, // Optimize generic operations on tuples new LetOverApply, // Lift blocks from receivers of applications new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. - List(new PreRecheck) :: // Preparations for recheck phase, enabled under -Yrecheck - List(new TestRecheck) :: // Test rechecking, enabled under -Yrecheck List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index b9552d97fca7..32b7b2feaeb3 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -20,9 +20,7 @@ import reporting.{Reporter, Suppression, Action} import reporting.Diagnostic import reporting.Diagnostic.Warning import rewrites.Rewrites - import profile.Profiler -import printing.XprintMode import parsing.Parsers.Parser import parsing.JavaParsers.JavaParser import typer.ImplicitRunInfo @@ -328,7 +326,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint val fusedPhase = ctx.base.fusedContaining(prevPhase) val echoHeader = f"[[syntax trees at end of $fusedPhase%25s]] // ${unit.source}" val tree = if ctx.isAfterTyper then unit.tpdTree else unit.untpdTree - val treeString = tree.show(using ctx.withProperty(XprintMode, Some(()))) + val treeString = fusedPhase.show(tree) last match { case SomePrintedTree(phase, lastTreeString) if lastTreeString == treeString => diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index a71fc3d40e92..5e7e6bd57c29 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1752,6 +1752,9 @@ object desugar { flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) case ext: ExtMethods => Block(List(ext), Literal(Constant(())).withSpan(ext.span)) + case CapturingTypeTree(refs, parent) => + val annot = New(scalaDot(tpnme.retains), List(refs)) + Annotated(parent, annot) } desugared.withSpan(tree.span) } @@ -1890,6 +1893,8 @@ object desugar { case _ => traverseChildren(tree) } }.traverse(expr) + case CapturingTypeTree(refs, parent) => + collect(parent) case _ => } collect(tree) diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 5bdd18705051..267c6b114a8c 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -260,16 +260,10 @@ object Trees { /** Tree's denotation can be derived from its type */ abstract class DenotingTree[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends Tree[T] { type ThisTree[-T >: Untyped] <: DenotingTree[T] - override def denot(using Context): Denotation = typeOpt match { + override def denot(using Context): Denotation = typeOpt.stripped match case tpe: NamedType => tpe.denot case tpe: ThisType => tpe.cls.denot - case tpe: AnnotatedType => tpe.stripAnnots match { - case tpe: NamedType => tpe.denot - case tpe: ThisType => tpe.cls.denot - case _ => NoDenotation - } case _ => NoDenotation - } } /** Tree's denot/isType/isTerm properties come from a subtree diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 40467dc5be3f..b9960cbb4652 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -147,6 +147,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case Floating } + /** {x1, ..., xN} T (only relevant under -Ycc) */ + case class CapturingTypeTree(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** Short-lived usage in typer, does not need copy/transform/fold infrastructure */ case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree @@ -458,7 +461,11 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def AppliedTypeTree(tpt: Tree, arg: Tree)(implicit src: SourceFile): AppliedTypeTree = AppliedTypeTree(tpt, arg :: Nil) - def TypeTree(tpe: Type)(using Context): TypedSplice = TypedSplice(TypeTree().withTypeUnchecked(tpe)) + def TypeTree(tpe: Type)(using Context): TypedSplice = + TypedSplice(TypeTree().withTypeUnchecked(tpe)) + + def InferredTypeTree(tpe: Type)(using Context): TypedSplice = + TypedSplice(new InferredTypeTree().withTypeUnchecked(tpe)) def unitLiteral(implicit src: SourceFile): Literal = Literal(Constant(())) @@ -646,6 +653,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: Number if (digits == tree.digits) && (kind == tree.kind) => tree case _ => finalize(tree, untpd.Number(digits, kind)) } + def CapturingTypeTree(tree: Tree)(refs: List[Tree], parent: Tree)(using Context): Tree = tree match + case tree: CapturingTypeTree if (refs eq tree.refs) && (parent eq tree.parent) => tree + case _ => finalize(tree, untpd.CapturingTypeTree(refs, parent)) + def TypedSplice(tree: Tree)(splice: tpd.Tree)(using Context): ProxyTree = tree match { case tree: TypedSplice if splice `eq` tree.splice => tree case _ => finalize(tree, untpd.TypedSplice(splice)(using ctx)) @@ -711,6 +722,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { tree case MacroTree(expr) => cpy.MacroTree(tree)(transform(expr)) + case CapturingTypeTree(refs, parent) => + cpy.CapturingTypeTree(tree)(transform(refs), transform(parent)) case _ => super.transformMoreCases(tree) } @@ -772,6 +785,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, splice) case MacroTree(expr) => this(x, expr) + case CapturingTypeTree(refs, parent) => + this(this(x, refs), parent) case _ => super.foldMoreCases(x, tree) } diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala new file mode 100644 index 000000000000..5f73b50a6bbe --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala @@ -0,0 +1,63 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.Trees.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import printing.Printer +import printing.Texts.Text + + +case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation: + import CaptureAnnotation.* + import tpd.* + + override def tree(using Context) = + val elems = refs.elems.toList.map { + case cr: TermRef => ref(cr) + case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr) + case cr: ThisType => This(cr.cls) + } + val arg = repeated(elems, TypeTree(defn.AnyType)) + New(symbol.typeRef, arg :: Nil) + + override def symbol(using Context) = defn.RetainsAnnot + + override def derivedAnnotation(tree: Tree)(using Context): Annotation = + unsupported("derivedAnnotation(Tree)") + + def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation = + if (this.refs eq refs) && (this.boxed == boxed) then this + else CaptureAnnotation(refs, boxed) + + override def sameAnnotation(that: Annotation)(using Context): Boolean = that match + case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2 + case _ => false + + override def mapWith(tp: TypeMap)(using Context) = + val elems = refs.elems.toList + val elems1 = elems.mapConserve(tp) + if elems1 eq elems then this + else if elems1.forall(_.isInstanceOf[CaptureRef]) + then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed) + else EmptyAnnotation + + override def refersToParamOf(tl: TermLambda)(using Context): Boolean = + refs.elems.exists { + case TermParamRef(tl1, _) => tl eq tl1 + case _ => false + } + + override def toText(printer: Printer): Text = refs.toText(printer) + + override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0) + + override def eql(that: Annotation) = that match + case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed) + case _ => false + +end CaptureAnnotation diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala new file mode 100644 index 000000000000..09064314b1bf --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -0,0 +1,82 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import util.Property.Key +import tpd.* + +private val Captures: Key[CaptureSet] = Key() +private val IsBoxed: Key[Unit] = Key() + +def retainedElems(tree: Tree)(using Context): List[Tree] = tree match + case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems + case _ => Nil + +extension (tree: Tree) + + def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef] + + def toCaptureSet(using Context): CaptureSet = + tree.getAttachment(Captures) match + case Some(refs) => refs + case None => + val refs = CaptureSet(retainedElems(tree).map(_.toCaptureRef)*) + .showing(i"toCaptureSet $tree --> $result", capt) + tree.putAttachment(Captures, refs) + refs + + def isBoxedCapturing(using Context): Boolean = + tree.hasAttachment(IsBoxed) + + def setBoxedCapturing()(using Context): Unit = + tree.putAttachment(IsBoxed, ()) + +extension (tp: Type) + + def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match + case CapturingType(p, r, b) => + if (parent eq p) && (refs eq r) then tp + else CapturingType(parent, refs, b) + + /** If this is type variable instantiated or upper bounded with a capturing type, + * the capture set associated with that type. Extended to and-or types and + * type proxies in the obvious way. If a term has a type with a boxed captureset, + * that captureset counts towards the capture variables of the envirionment. + */ + def boxedCaptured(using Context): CaptureSet = + def getBoxed(tp: Type): CaptureSet = tp match + case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty + case tp: TypeProxy => getBoxed(tp.superType) + case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2) + case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2) + case _ => CaptureSet.empty + getBoxed(tp) + + def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty + + def canHaveInferredCapture(using Context): Boolean = tp match + case tp: TypeRef if tp.symbol.isClass => + !tp.symbol.isValueClass && tp.symbol != defn.AnyClass + case _: TypeVar | _: TypeParamRef => + false + case tp: TypeProxy => + tp.superType.canHaveInferredCapture + case tp: AndType => + tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture + case tp: OrType => + tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture + case _ => + false + + def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match + case CapturingType(parent, _, _) => + parent.stripCapturing + case atd @ AnnotatedType(parent, annot) => + atd.derivedAnnotatedType(parent.stripCapturing, annot) + case _ => + tp diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala new file mode 100644 index 000000000000..f8ca2f87e3c5 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -0,0 +1,577 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Flags.*, Contexts.*, Decorators.* +import config.Printers.capt +import Annotations.Annotation +import annotation.threadUnsafe +import annotation.constructorOnly +import annotation.internal.sharable +import reporting.trace +import printing.{Showable, Printer} +import printing.Texts.* +import util.{SimpleIdentitySet, Property} +import util.common.alwaysTrue +import scala.collection.mutable + +/** A class for capture sets. Capture sets can be constants or variables. + * Capture sets support inclusion constraints <:< where <:< is subcapturing. + * They also allow mapping with arbitrary functions from elements to capture sets, + * by supporting a monadic flatMap operation. That is, constraints can be + * of one of the following forms + * + * cs1 <:< cs2 + * cs1 = ∪ {f(x) | x ∈ cs2} + * + * where the `f`s are arbitrary functions from capture references to capture sets. + * We call the resulting constraint system "monadic set constraints". + */ +sealed abstract class CaptureSet extends Showable: + import CaptureSet.* + + /** The elements of this capture set. For capture variables, + * the elements known so far. + */ + def elems: Refs + + /** Is this capture set constant (i.e. not an unsolved capture variable)? + * Solved capture variables count as constant. + */ + def isConst: Boolean + + /** Is this capture set always empty? For capture veraiables, returns + * always false + */ + def isAlwaysEmpty: Boolean + + /** Is this capture set definitely non-empty? */ + final def isNotEmpty: Boolean = !elems.isEmpty + + /** Cast to variable. @pre: @isConst */ + def asVar: Var = + assert(!isConst) + asInstanceOf[Var] + + /** Add new elements to this capture set if allowed. + * @pre `newElems` is not empty and does not overlap with `this.elems`. + * Constant capture sets never allow to add new elements. + * Variables allow it if and only if the new elements can be included + * in all their supersets. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if elements were added, or a conflicting + * capture set that prevents addition otherwise. + */ + protected def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult + + /** If this is a variable, add `cs` as a super set */ + protected def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult + + /** If `cs` is a variable, add this capture set as one of its super sets */ + protected def addSub(cs: CaptureSet)(using Context): this.type = + cs.addSuper(this)(using ctx, UnrecordedState) + this + + /** Try to include all references of `elems` that are not yet accounted by this + * capture set. Inclusion is via `addNewElems`. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if all unaccounted elements could be added, + * capture set that prevents addition otherwise. + */ + protected final def tryInclude(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val unaccounted = elems.filter(!accountsFor(_)) + if unaccounted.isEmpty then CompareResult.OK + else addNewElems(unaccounted, origin) + + protected final def tryInclude(elem: CaptureRef, origin: CaptureSet)(using Context, VarState): CompareResult = + if accountsFor(elem) then CompareResult.OK + else addNewElems(elem.singletonCaptureSet.elems, origin) + + extension (x: CaptureRef) private def subsumes(y: CaptureRef) = + (x eq y) + || y.match + case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ? + case _ => false + + /** {x} <:< this where <:< is subcapturing, but treating all variables + * as frozen. + */ + def accountsFor(x: CaptureRef)(using ctx: Context): Boolean = + reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) { + elems.exists(_.subsumes(x)) + || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK + } + + /** The subcapturing test */ + final def subCaptures(that: CaptureSet, frozen: Boolean)(using Context): CompareResult = + subCaptures(that)(using ctx, if frozen then FrozenState else VarState()) + + private def subCaptures(that: CaptureSet)(using Context, VarState): CompareResult = + def recur(elems: List[CaptureRef]): CompareResult = elems match + case elem :: elems1 => + var result = that.tryInclude(elem, this) + if !result.isOK && !elem.isRootCapability && summon[VarState] != FrozenState then + result = elem.captureSetOfInfo.subCaptures(that) + if result.isOK then + recur(elems1) + else + varState.abort() + result + case Nil => + addSuper(that) + recur(elems.toList) + .showing(i"subcaptures $this <:< $that = ${result.show}", capt) + + def =:= (that: CaptureSet)(using Context): Boolean = + this.subCaptures(that, frozen = true).isOK + && that.subCaptures(this, frozen = true).isOK + + /** The smallest capture set (via <:<) that is a superset of both + * `this` and `that` + */ + def ++ (that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then that + else if that.subCaptures(this, frozen = true).isOK then this + else if this.isConst && that.isConst then Const(this.elems ++ that.elems) + else Var(this.elems ++ that.elems).addSub(this).addSub(that) + + /** The smallest superset (via <:<) of this capture set that also contains `ref`. + */ + def + (ref: CaptureRef)(using Context): CaptureSet = + this ++ ref.singletonCaptureSet + + /** The largest capture set (via <:<) that is a subset of both `this` and `that` + */ + def **(that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then this + else if that.subCaptures(this, frozen = true).isOK then that + else if this.isConst && that.isConst then Const(elems.intersect(that.elems)) + else if that.isConst then Intersected(this.asVar, that) + else Intersected(that.asVar, this) + + def -- (that: CaptureSet.Const)(using Context): CaptureSet = + val elems1 = elems.filter(!that.accountsFor(_)) + if elems1.size == elems.size then this + else if this.isConst then Const(elems1) + else Diff(asVar, that) + + def - (ref: CaptureRef)(using Context): CaptureSet = + this -- ref.singletonCaptureSet + + def filter(p: CaptureRef => Boolean)(using Context): CaptureSet = + if this.isConst then Const(elems.filter(p)) + else Filtered(asVar, p) + + /** capture set obtained by applying `f` to all elements of the current capture set + * and joining the results. If the current capture set is a variable, the same + * transformation is applied to all future additions of new elements. + */ + def map(tm: TypeMap)(using Context): CaptureSet = tm match + case tm: BiTypeMap => + val mappedElems = elems.map(tm.forward) + if isConst then Const(mappedElems) + else BiMapped(asVar, tm, mappedElems) + case _ => + val mapped = mapRefs(elems, tm, tm.variance) + if isConst then mapped + else Mapped(asVar, tm, tm.variance, mapped) + + def substParams(tl: BindingType, to: List[Type])(using Context) = + map(Substituters.SubstParamsMap(tl, to)) + + /** An upper approximation of this capture set. This is the set itself + * except for real (non-mapped, non-filtered) capture set variables, where + * it is the intersection of all upper approximations of known supersets + * of the variable. + * The upper approximation is meaningful only if it is constant. If not, + * `upperApprox` can return an arbitrary capture set variable. + */ + protected def upperApprox(origin: CaptureSet)(using Context): CaptureSet + + protected def propagateSolved()(using Context): Unit = () + + def toRetainsTypeArg(using Context): Type = + assert(isConst) + ((NoType: Type) /: elems) ((tp, ref) => + if tp.exists then OrType(tp, ref, soft = false) else ref) + + def toRegularAnnotation(using Context): Annotation = + Annotation(CaptureAnnotation(this, boxed = false).tree) + + override def toText(printer: Printer): Text = + Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") + +object CaptureSet: + type Refs = SimpleIdentitySet[CaptureRef] + type Vars = SimpleIdentitySet[Var] + type Deps = SimpleIdentitySet[CaptureSet] + + /** If set to `true`, capture stack traces that tell us where sets are created */ + private final val debugSets = false + + private val emptySet = SimpleIdentitySet.empty + @sharable private var varId = 0 + + val empty: CaptureSet.Const = Const(emptySet) + + /** The universal capture set `{*}` */ + def universal(using Context): CaptureSet = + defn.captureRoot.termRef.singletonCaptureSet + + /** Used as a recursion brake */ + @sharable private[dotc] val Pending = Const(SimpleIdentitySet.empty) + + def apply(elems: CaptureRef*)(using Context): CaptureSet.Const = + if elems.isEmpty then empty + else Const(SimpleIdentitySet(elems.map(_.normalizedRef)*)) + + def apply(elems: Refs)(using Context): CaptureSet.Const = + if elems.isEmpty then empty else Const(elems) + + class Const private[CaptureSet] (val elems: Refs) extends CaptureSet: + assert(elems != null) + def isConst = true + def isAlwaysEmpty = elems.isEmpty + + def addNewElems(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState) = CompareResult.OK + + def upperApprox(origin: CaptureSet)(using Context): CaptureSet = this + + override def toString = elems.toString + end Const + + class Var(initialElems: Refs = emptySet) extends CaptureSet: + val id = + varId += 1 + varId + + private var isSolved: Boolean = false + + var elems: Refs = initialElems + var deps: Deps = emptySet + def isConst = isSolved + def isAlwaysEmpty = false + + private def recordElemsState()(using VarState): Boolean = + varState.getElems(this) match + case None => varState.putElems(this, elems) + case _ => true + + private[CaptureSet] def recordDepsState()(using VarState): Boolean = + varState.getDeps(this) match + case None => varState.putDeps(this, deps) + case _ => true + + def resetElems()(using state: VarState): Unit = + elems = state.elems(this) + + def resetDeps()(using state: VarState): Unit = + deps = state.deps(this) + + def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if !isConst && recordElemsState() then + elems ++= newElems + // assert(id != 2 || elems.size != 2, this) + (CompareResult.OK /: deps) { (r, dep) => + r.andAlso(dep.tryInclude(newElems, this)) + } + else + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult = + if (cs eq this) || cs.elems.contains(defn.captureRoot.termRef) || isConst then + CompareResult.OK + else if recordDepsState() then + deps += cs + CompareResult.OK + else + CompareResult.fail(this) + + private var computingApprox = false + + final def upperApprox(origin: CaptureSet)(using Context): CaptureSet = + if computingApprox then universal + else if isConst then this + else + computingApprox = true + try computeApprox(origin).ensuring(_.isConst) + finally computingApprox = false + + protected def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + (universal /: deps) { (acc, sup) => acc ** sup.upperApprox(this) } + + def solve(variance: Int)(using Context): Unit = + if variance < 0 && !isConst then + val approx = upperApprox(empty) + //println(i"solving var $this $approx ${approx.isConst} deps = ${deps.toList}") + if approx.isConst then + val newElems = approx.elems -- elems + if newElems.isEmpty || addNewElems(newElems, empty)(using ctx, VarState()).isOK then + markSolved() + + def markSolved()(using Context): Unit = + isSolved = true + deps.foreach(_.propagateSolved()) + + protected def ids(using Context): String = + val trail = this.match + case dv: DerivedVar => dv.source.ids + case _ => "" + s"$id${getClass.getSimpleName.take(1)}$trail" + + override def toText(printer: Printer): Text = inContext(printer.printerContext) { + for vars <- ctx.property(ShownVars) do vars += this + super.toText(printer) ~ (Str(ids) provided !isConst && ctx.settings.YccDebug.value) + } + + override def toString = s"Var$id$elems" + end Var + + abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context) + extends Var(initialElems): + def source: Var + + addSub(source) + + override def propagateSolved()(using Context) = + if source.isConst && !isConst then markSolved() + end DerivedVar + + /** A variable that changes when `source` changes, where all additional new elements are mapped + * using ∪ { f(x) | x <- elems } + */ + class Mapped private[CaptureSet] + (val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context) + extends DerivedVar(initial.elems): + addSub(initial) + val stack = if debugSets then (new Throwable).getStackTrace().take(20) else null + + private def whereCreated(using Context): String = + if stack == null then "" + else i""" + |Stack trace of variable creation:" + |${stack.mkString("\n")}""" + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val added = + if origin eq source then + mapRefs(newElems, tm, variance) + else + if variance <= 0 && !origin.isConst && (origin ne initial) then + report.warning(i"trying to add elems $newElems from unrecognized source $origin of mapped set $this$whereCreated") + return CompareResult.fail(this) + Const(newElems) + super.addNewElems(added.elems, origin) + .andAlso { + if added.isConst then CompareResult.OK + else if added.asVar.recordDepsState() then { addSub(added); CompareResult.OK } + else CompareResult.fail(this) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).map(tm) + + override def propagateSolved()(using Context) = + if initial.isConst then super.propagateSolved() + + override def toString = s"Mapped$id($source, elems = $elems)" + end Mapped + + class BiMapped private[CaptureSet] + (val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) + extends DerivedVar(initialElems): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if origin eq source then + super.addNewElems(newElems.map(bimap.forward), origin) + else + super.addNewElems(newElems, origin) + .andAlso { + source.tryInclude(newElems.map(bimap.backward), this) + .showing(i"propagating new elems $newElems backward from $this to $source", capt) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + val supApprox = super.computeApprox(this) + if source eq origin then supApprox.map(bimap.inverseTypeMap) + else source.upperApprox(this).map(bimap) ** supApprox + + override def toString = s"BiMapped$id($source, elems = $elems)" + end BiMapped + + /** A variable with elements given at any time as { x <- source.elems | p(x) } */ + class Filtered private[CaptureSet] + (val source: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context) + extends DerivedVar(source.elems.filter(p)): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + super.addNewElems(newElems.filter(p), origin) + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).filter(p) + + override def toString = s"${getClass.getSimpleName}$id($source, elems = $elems)" + end Filtered + + /** A variable with elements given at any time as { x <- source.elems | !other.accountsFor(x) } */ + class Diff(source: Var, other: Const)(using Context) + extends Filtered(source, !other.accountsFor(_)) + + /** A variable with elements given at any time as { x <- source.elems | other.accountsFor(x) } */ + class Intersected(source: Var, other: CaptureSet)(using Context) + extends Filtered(source, other.accountsFor(_)): + addSub(other) + + def extrapolateCaptureRef(r: CaptureRef, tm: TypeMap, variance: Int)(using Context): CaptureSet = + val r1 = tm(r) + val upper = r1.captureSet + def isExact = + upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1) + if variance > 0 || isExact then upper + else if variance < 0 then CaptureSet.empty + else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting") + + def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet = + ((empty: CaptureSet) /: xs)((cs, x) => cs ++ f(x)) + + def mapRefs(xs: Refs, tm: TypeMap, variance: Int)(using Context): CaptureSet = + mapRefs(xs, extrapolateCaptureRef(_, tm, variance)) + + type CompareResult = CompareResult.Type + + /** None = ok, Some(cs) = failure since not a subset of cs */ + object CompareResult: + opaque type Type = CaptureSet + val OK: Type = Const(emptySet) + def fail(cs: CaptureSet): Type = cs + extension (result: Type) + def isOK: Boolean = result eq OK + def blocking: CaptureSet = result + def show: String = if result.isOK then "OK" else result.toString + def andAlso(op: Context ?=> Type)(using Context): Type = if result.isOK then op else result + + class VarState: + private val elemsMap: util.EqHashMap[Var, Refs] = new util.EqHashMap + private val depsMap: util.EqHashMap[Var, Deps] = new util.EqHashMap + + def elems(v: Var): Refs = elemsMap(v) + def getElems(v: Var): Option[Refs] = elemsMap.get(v) + def putElems(v: Var, elems: Refs): Boolean = { elemsMap(v) = elems; true } + + def deps(v: Var): Deps = depsMap(v) + def getDeps(v: Var): Option[Deps] = depsMap.get(v) + def putDeps(v: Var, deps: Deps): Boolean = { depsMap(v) = deps; true } + + def abort(): Unit = + elemsMap.keysIterator.foreach(_.resetElems()(using this)) + depsMap.keysIterator.foreach(_.resetDeps()(using this)) + end VarState + + @sharable + object FrozenState extends VarState: + override def putElems(v: Var, refs: Refs) = false + override def putDeps(v: Var, deps: Deps) = false + override def abort(): Unit = () + + @sharable + object UnrecordedState extends VarState: + override def putElems(v: Var, refs: Refs) = true + override def putDeps(v: Var, deps: Deps) = true + override def abort(): Unit = () + + def varState(using state: VarState): VarState = state + + def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet = + CaptureSet.empty + /* + def captureSetOf(tp: Type): CaptureSet = tp match + case tp: TypeRef if tp.symbol.is(ParamAccessor) => + def mapArg(accs: List[Symbol], tps: List[Type]): CaptureSet = accs match + case acc :: accs1 if tps.nonEmpty => + if acc == tp.symbol then tps.head.captureSet + else mapArg(accs1, tps.tail) + case _ => + empty + mapArg(cinfo.cls.paramAccessors, argTypes) + case _ => + tp.captureSet + val css = + for + parent <- cinfo.parents if parent.classSymbol == defn.RetainingClass + arg <- parent.argInfos + yield captureSetOf(arg) + css.foldLeft(empty)(_ ++ _) + */ + def ofInfo(ref: CaptureRef)(using Context): CaptureSet = ref match + case ref: ThisType => + val declaredCaptures = ref.cls.givenSelfType.captureSet + ref.cls.paramAccessors.foldLeft(declaredCaptures) ((cs, acc) => + cs ++ acc.termRef.captureSetOfInfo) // ^^^ need to also include outer references of inner classes + .showing(i"cc info $ref with ${ref.cls.paramAccessors.map(_.termRef)}%, % = $result", capt) + case ref: TermRef if ref.isRootCapability => ref.singletonCaptureSet + case _ => ofType(ref.underlying) + + def ofType(tp: Type)(using Context): CaptureSet = + def recur(tp: Type): CaptureSet = tp.dealias match + case tp: TermRef => + tp.captureSet + case tp: TermParamRef => + tp.captureSet + case _: TypeRef | _: TypeParamRef => + empty + case CapturingType(parent, refs, _) => + recur(parent) ++ refs + case AppliedType(tycon, args) => + val cs = recur(tycon) + tycon.typeParams match + case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args) + case _ => cs + case tp: TypeProxy => + recur(tp.underlying) + case AndType(tp1, tp2) => + recur(tp1) ** recur(tp2) + case OrType(tp1, tp2) => + recur(tp1) ++ recur(tp2) + case tp: ClassInfo => + ofClass(tp, Nil) + case _ => + empty + recur(tp) + .showing(i"capture set of $tp = $result", capt) + + private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key() + + def withCaptureSetsExplained[T](op: Context ?=> T)(using ctx: Context): T = + if ctx.settings.YccDebug.value then + val shownVars = mutable.Set[Var]() + inContext(ctx.withProperty(ShownVars, Some(shownVars))) { + try op + finally + val reachable = mutable.Set[Var]() + val todo = mutable.Queue[Var]() ++= shownVars + def incl(cv: Var): Unit = + if !reachable.contains(cv) then todo += cv + while todo.nonEmpty do + val cv = todo.dequeue() + if !reachable.contains(cv) then + reachable += cv + cv.deps.foreach { + case cv: Var => incl(cv) + case _ => + } + cv match + case cv: DerivedVar => incl(cv.source) + case _ => + val allVars = reachable.toArray.sortBy(_.id) + println(i"Capture set dependencies:") + for cv <- allVars do + println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %") + } + else op +end CaptureSet diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala new file mode 100644 index 000000000000..2eeb1ff41b72 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -0,0 +1,21 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.* + +object CapturingType: + + def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type = + if refs.isAlwaysEmpty then parent + else AnnotatedType(parent, CaptureAnnotation(refs, boxed)) + + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = + if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then + tp.annot match + case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed)) + case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + else None + +end CapturingType diff --git a/compiler/src/dotty/tools/dotc/config/Config.scala b/compiler/src/dotty/tools/dotc/config/Config.scala index ac1708378e73..a54987b23ecc 100644 --- a/compiler/src/dotty/tools/dotc/config/Config.scala +++ b/compiler/src/dotty/tools/dotc/config/Config.scala @@ -227,4 +227,9 @@ object Config { * reduces the number of allocated denotations by ~50%. */ inline val reuseSymDenotations = true + + /** If true, print capturing types in the form `{c} T`. + * If false, print them in the form `T @retains(c)`. + */ + inline val printCaptureSetsAsPrefix = true } diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index b71e1e7f188a..d20d482b062e 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -12,6 +12,7 @@ object Printers { val default = new Printer + val capt = noPrinter val constr = noPrinter val core = noPrinter val checks = noPrinter diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 56e6ab14fae5..b0e161399e75 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -300,6 +300,8 @@ private sealed trait YSettings: val YcheckInit: Setting[Boolean] = BooleanSetting("-Ysafe-init", "Ensure safe initialization of objects") val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation") val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)") + val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references") + val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Debug info for captured references") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index b8d62210ce26..d0172c82972c 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -48,7 +48,7 @@ object Annotations { /** The tree evaluation has finished. */ def isEvaluated: Boolean = true - /** Normally, type map over all tree nodes of this annotation, but can + /** Normally, applies a type map to all tree nodes of this annotation, but can * be overridden. Returns EmptyAnnotation if type type map produces a range * type, since ranges cannot be types of trees. */ @@ -86,6 +86,10 @@ object Annotations { def sameAnnotation(that: Annotation)(using Context): Boolean = symbol == that.symbol && tree.sameTree(that.tree) + + /** Operations for hash-consing, can be overridden */ + def hash: Int = System.identityHashCode(this) + def eql(that: Annotation) = this eq that } case class ConcreteAnnotation(t: Tree) extends Annotation: diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index c3b462f3b179..ee296c47a305 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -14,6 +14,7 @@ import typer.ImportInfo.RootRef import Comments.CommentsContext import Comments.Comment import util.Spans.NoSpan +import cc.{CapturingType, CaptureSet} import scala.annotation.tailrec @@ -143,11 +144,13 @@ class Definitions { private def enterMethod(cls: ClassSymbol, name: TermName, info: Type, flags: FlagSet = EmptyFlags): TermSymbol = newMethod(cls, name, info, flags).entered - private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = { - val sym = newPermanentSymbol(ScalaPackageClass, name, flags, TypeAlias(tpe)) + private def enterPermanentSymbol(name: Name, info: Type, flags: FlagSet = EmptyFlags): Symbol = + val sym = newPermanentSymbol(ScalaPackageClass, name, flags, info) ScalaPackageClass.currentPackageDecls.enter(sym) sym - } + + private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = + enterPermanentSymbol(name, TypeAlias(tpe), flags).asType private def enterBinaryAlias(name: TypeName, op: (Type, Type) => Type): TypeSymbol = enterAliasType(name, @@ -440,6 +443,7 @@ class Definitions { @tu lazy val andType: TypeSymbol = enterBinaryAlias(tpnme.AND, AndType(_, _)) @tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _, soft = false)) + @tu lazy val captureRoot: TermSymbol = enterPermanentSymbol(nme.CAPTURE_ROOT, AnyType).asTerm /** Marker method to indicate an argument to a call-by-name parameter. * Created by byNameClosures and elimByName, eliminated by Erasure, @@ -941,6 +945,8 @@ class Definitions { @tu lazy val FunctionalInterfaceAnnot: ClassSymbol = requiredClass("java.lang.FunctionalInterface") @tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName") @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") + @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") + @tu lazy val AbilityAnnot: ClassSymbol = requiredClass("scala.annotation.ability") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") @@ -1514,6 +1520,9 @@ class Definitions { def isFunctionType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) + def isFunctionOrPolyType(tp: RefinedType)(using Context): Boolean = + isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass) + // Specialized type parameters defined for scala.Function{0,1,2}. @tu lazy val Function1SpecializedParamTypes: collection.Set[TypeRef] = Set(IntType, LongType, FloatType, DoubleType) @@ -1812,7 +1821,7 @@ class Definitions { this.initCtx = ctx if (!isInitialized) { // force initialization of every symbol that is synthesized or hijacked by the compiler - val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() :+ JavaEnumClass + val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() ++ List(JavaEnumClass, captureRoot) isInitialized = true } diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 17df7149e9be..b6f8f00a91fb 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -12,6 +12,7 @@ import config.Printers.constr import reflect.ClassTag import annotation.tailrec import annotation.internal.sharable +import cc.{CapturingType, derivedCapturingType} object OrderingConstraint { @@ -328,6 +329,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, case tp: TypeVar => val underlying1 = recur(tp.underlying, fromBelow) if underlying1 ne tp.underlying then underlying1 else tp + case CapturingType(parent, refs, _) => + val parent1 = recur(parent, fromBelow) + if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp case tp: AnnotatedType => val parent1 = recur(tp.parent, fromBelow) if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 0294de01f36e..a1faa1428188 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -13,10 +13,12 @@ import scala.collection.mutable.ListBuffer import dotty.tools.dotc.transform.MegaPhase._ import dotty.tools.dotc.transform._ import Periods._ -import parsing.{ Parser} +import parsing.Parser +import printing.XprintMode import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import typer.ImportInfo.withRootImports -import ast.tpd +import ast.{tpd, untpd} import scala.annotation.internal.sharable import scala.util.control.NonFatal @@ -216,6 +218,7 @@ object Phases { private var myCountOuterAccessesPhase: Phase = _ private var myFlattenPhase: Phase = _ private var myGenBCodePhase: Phase = _ + private var myCheckCapturesPhase: Phase = _ final def parserPhase: Phase = myParserPhase final def typerPhase: Phase = myTyperPhase @@ -238,6 +241,7 @@ object Phases { final def countOuterAccessesPhase = myCountOuterAccessesPhase final def flattenPhase: Phase = myFlattenPhase final def genBCodePhase: Phase = myGenBCodePhase + final def checkCapturesPhase: Phase = myCheckCapturesPhase private def setSpecificPhases() = { def phaseOfClass(pclass: Class[?]) = phases.find(pclass.isInstance).getOrElse(NoPhase) @@ -262,7 +266,8 @@ object Phases { myFlattenPhase = phaseOfClass(classOf[Flatten]) myExplicitOuterPhase = phaseOfClass(classOf[ExplicitOuter]) myGettersPhase = phaseOfClass(classOf[Getters]) - myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myCheckCapturesPhase = phaseOfClass(classOf[CheckCaptures]) } final def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id @@ -312,6 +317,10 @@ object Phases { unitCtx.compilationUnit } + /** Convert a compilation unit's tree to a string; can be overridden */ + def show(tree: untpd.Tree)(using Context): String = + tree.show(using ctx.withProperty(XprintMode, Some(()))) + def description: String = phaseName /** Output should be checkable by TreeChecker */ @@ -438,6 +447,7 @@ object Phases { def lambdaLiftPhase(using Context): Phase = ctx.base.lambdaLiftPhase def flattenPhase(using Context): Phase = ctx.base.flattenPhase def genBCodePhase(using Context): Phase = ctx.base.genBCodePhase + def checkCapturesPhase(using Context): Phase = ctx.base.checkCapturesPhase def unfusedPhases(using Context): Array[Phase] = ctx.base.phases diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 2e0b229ca42c..2a575cd0ad4d 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -275,6 +275,7 @@ object StdNames { // Compiler-internal val ANYname: N = "" + val CAPTURE_ROOT: N = "*" val COMPANION: N = "" val CONSTRUCTOR: N = "" val STATIC_CONSTRUCTOR: N = "" @@ -362,6 +363,7 @@ object StdNames { val AppliedTypeTree: N = "AppliedTypeTree" val ArrayAnnotArg: N = "ArrayAnnotArg" val CAP: N = "CAP" + val ClassManifestFactory: N = "ClassManifestFactory" val Constant: N = "Constant" val ConstantType: N = "ConstantType" val Eql: N = "Eql" @@ -439,7 +441,6 @@ object StdNames { val canEqualAny : N = "canEqualAny" val cbnArg: N = "" val checkInitialized: N = "checkInitialized" - val ClassManifestFactory: N = "ClassManifestFactory" val classOf: N = "classOf" val classType: N = "classType" val clone_ : N = "clone" @@ -571,6 +572,7 @@ object StdNames { val reflectiveSelectable: N = "reflectiveSelectable" val reify : N = "reify" val releaseFence : N = "releaseFence" + val retains: N = "retains" val rootMirror : N = "rootMirror" val run: N = "run" val runOrElse: N = "runOrElse" diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index f00edcb189c6..b277f2cd8619 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -161,8 +161,9 @@ object Substituters: .mapOver(tp) } - final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap { + final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) + def inverse(tp: Type): Type = tp.subst(to, from) } final class Subst1Map(from: Symbol, to: Type)(using Context) extends DeepTypeMap { @@ -177,8 +178,9 @@ object Substituters: def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) } - final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap { + final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = substSym(tp, from, to, this)(using mapCtx) + def inverse(tp: Type) = tp.substSym(to, from) } final class SubstThisMap(from: ClassSymbol, to: Type)(using Context) extends DeepTypeMap { diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index fcfbba208eb9..f5e5ea31d845 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -24,6 +24,7 @@ import config.Config import reporting._ import collection.mutable import transform.TypeUtils._ +import cc.{CapturingType, derivedCapturingType} import scala.annotation.internal.sharable @@ -224,6 +225,8 @@ object SymDenotations { ensureCompleted(); myAnnotations } + final def annotationsUNSAFE(using Context): List[Annotation] = myAnnotations + /** Update the annotations of this denotation */ final def annotations_=(annots: List[Annotation]): Unit = myAnnotations = annots @@ -1509,8 +1512,7 @@ object SymDenotations { case tp: ExprType => hasSkolems(tp.resType) case tp: AppliedType => hasSkolems(tp.tycon) || tp.args.exists(hasSkolems) case tp: LambdaType => tp.paramInfos.exists(hasSkolems) || hasSkolems(tp.resType) - case tp: AndType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) - case tp: OrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) + case tp: AndOrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) case tp: AnnotatedType => hasSkolems(tp.parent) case _ => false } @@ -2166,6 +2168,9 @@ object SymDenotations { case tp: TypeParamRef => // uncachable, since baseType depends on context bounds recur(TypeComparer.bounds(tp).hi) + case CapturingType(parent, refs, _) => + tp.derivedCapturingType(recur(parent), refs) + case tp: TypeProxy => def computeTypeProxy = { val superTp = tp.superType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index add2030b6a82..b942526fa59e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -24,6 +24,7 @@ import typer.Applications.productSelectorTypes import reporting.trace import NullOpsDecorator._ import annotation.constructorOnly +import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing} /** Provides methods to compare types. */ @@ -325,6 +326,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling compareWild case tp2: LazyRef => isBottom(tp1) || !tp2.evaluating && recur(tp1, tp2.ref) + case CapturingType(_, _, _) => + secondTry case tp2: AnnotatedType if !tp2.isRefining => recur(tp1, tp2.parent) case tp2: ThisType => @@ -444,8 +447,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // See i859.scala for an example where we hit this case. tp2.isRef(AnyClass, skipRefined = false) || !tp1.evaluating && recur(tp1.ref, tp2) - case tp1: AnnotatedType if !tp1.isRefining => - recur(tp1.parent, tp2) case AndType(tp11, tp12) => if (tp11.stripTypeVar eq tp12.stripTypeVar) recur(tp11, tp2) else thirdTry @@ -489,7 +490,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // and then need to check that they are indeed supertypes of the original types // under -Ycheck. Test case is i7965.scala. - case tp1: MatchType => + case CapturingType(parent1, refs1, _) => + if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK then + recur(parent1, tp2) + else + thirdTry + case tp1: AnnotatedType if !tp1.isRefining => + recur(tp1.parent, tp2) + case tp1: MatchType => val reduced = tp1.reduced if (reduced.exists) recur(reduced, tp2) else thirdTry case _: FlexType => @@ -527,8 +535,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // Note: We would like to replace this by `if (tp1.hasHigherKind)` // but right now we cannot since some parts of the standard library rely on the // idiom that e.g. `List <: Any`. We have to bootstrap without scalac first. - if (cls2 eq AnyClass) return true - if (cls2 == defn.SingletonClass && tp1.isStable) return true + if cls2 eq AnyClass then return true + if cls2 == defn.SingletonClass && tp1.isStable then return true return tryBaseType(cls2) } else if (cls2.is(JavaDefined)) { @@ -597,6 +605,28 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubRefinements(tp1w.asInstanceOf[RefinedType], tp2, skipped2) && recur(tp1, skipped2) + def isSubInfo(info1: Type, info2: Type): Boolean = (info1, info2) match + case (info1: PolyType, info2: PolyType) => + sameLength(info1.paramNames, info2.paramNames) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case (info1: MethodType, info2: MethodType) => + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case _ => + isSubType(info1, info2) + + if ctx.phase == Phases.checkCapturesPhase then + if defn.isFunctionType(tp2) then + tp1.widenDealias match + case tp1: RefinedType => + return isSubInfo(tp1.refinedInfo, tp2.refinedInfo) + case _ => + else if tp2.parent.typeSymbol == defn.PolyFunctionClass then + tp1.member(nme.apply).info match + case info1: PolyType => + return isSubInfo(info1, tp2.refinedInfo) + case _ => + compareRefined case tp2: RecType => def compareRec = tp1.safeDealias match { @@ -727,13 +757,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def compareTypeBounds = tp1 match { case tp1 @ TypeBounds(lo1, hi1) => ((lo2 eq NothingType) || isSubType(lo2, lo1)) && - ((hi2 eq AnyType) && !hi1.isLambdaSub || (hi2 eq AnyKindType) || isSubType(hi1, hi2)) + ((hi2 eq AnyType) && !hi1.isLambdaSub + || (hi2 eq AnyKindType) + || isSubType(hi1, hi2)) case tp1: ClassInfo => tp2 contains tp1 case _ => false } compareTypeBounds + case CapturingType(parent2, _, _) => + recur(tp1, parent2) || fourthTry case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) @@ -780,6 +814,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp: AppliedType => isNullable(tp.tycon) case AndType(tp1, tp2) => isNullable(tp1) && isNullable(tp2) case OrType(tp1, tp2) => isNullable(tp1) || isNullable(tp2) + case CapturingType(tp1, _, _) => isNullable(tp1) case _ => false } val sym1 = tp1.symbol @@ -798,7 +833,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false } case _ => false - comparePaths || isSubType(tp1.underlying.widenExpr, tp2, approx.addLow) + comparePaths || { + var tp1w = tp1.underlying.widenExpr + tp1 match + case tp1: CaptureRef if tp1.isTracked => + val stripped = tp1w.stripCapturing + tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false) + case _ => + isSubType(tp1w, tp2, approx.addLow) + } case tp1: RefinedType => isNewSubType(tp1.parent) case tp1: RecType => @@ -1769,69 +1812,68 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def hasMatchingMember(name: Name, tp1: Type, tp2: RefinedType): Boolean = trace(i"hasMatchingMember($tp1 . $name :? ${tp2.refinedInfo}), mbr: ${tp1.member(name).info}", subtyping) { - def qualifies(m: SingleDenotation): Boolean = - // If the member is an abstract type and the prefix is a path, compare the member itself - // instead of its bounds. This case is needed situations like: - // - // class C { type T } - // val foo: C - // foo.type <: C { type T {= , <: , >:} foo.T } - // - // or like: - // - // class C[T] - // C[?] <: C[TV] - // - // where TV is a type variable. See i2397.scala for an example of the latter. - def matchAbstractTypeMember(info1: Type): Boolean = info1 match { - case TypeBounds(lo, hi) if lo ne hi => - tp2.refinedInfo match { - case rinfo2: TypeBounds if tp1.isStable => - val ref1 = tp1.widenExpr.select(name) - isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) - case _ => - false - } - case _ => false - } + // If the member is an abstract type and the prefix is a path, compare the member itself + // instead of its bounds. This case is needed situations like: + // + // class C { type T } + // val foo: C + // foo.type <: C { type T {= , <: , >:} foo.T } + // + // or like: + // + // class C[T] + // C[?] <: C[TV] + // + // where TV is a type variable. See i2397.scala for an example of the latter. + def matchAbstractTypeMember(info1: Type): Boolean = info1 match { + case TypeBounds(lo, hi) if lo ne hi => + tp2.refinedInfo match { + case rinfo2: TypeBounds if tp1.isStable => + val ref1 = tp1.widenExpr.select(name) + isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) + case _ => + false + } + case _ => false + } - // An additional check for type member matching: If the refinement of the - // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. - // then the symbol referred to in the subtype must have a signature that coincides - // in its parameters with the refinement's signature. The reason for the check - // is that if the refinement does not refer to a member symbol, we will have to - // resort to reflection to invoke the member. And Java reflection needs to know exact - // erased parameter types. See neg/i12211.scala. Other reflection algorithms could - // conceivably dispatch without knowning precise parameter signatures. One can signal - // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, - // in which case the signature test is elided. - def sigsOK(symInfo: Type, info2: Type) = - tp2.underlyingClassRef(refinementOK = true).member(name).exists - || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) - || symInfo.isInstanceOf[MethodType] - && symInfo.signature.consistentParams(info2.signature) - - // A relaxed version of isSubType, which compares method types - // under the standard arrow rule which is contravarient in the parameter types, - // but under the condition that signatures might have to match (see sigsOK) - // This relaxed version is needed to correctly compare dependent function types. - // See pos/i12211.scala. - def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = - info2 match - case info2: MethodType => - info1 match - case info1: MethodType => - val symInfo1 = symInfo.stripPoly - matchingMethodParams(info1, info2, precise = false) - && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) - && sigsOK(symInfo1, info2) - case _ => isSubType(info1, info2) - case _ => isSubType(info1, info2) + // An additional check for type member matching: If the refinement of the + // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. + // then the symbol referred to in the subtype must have a signature that coincides + // in its parameters with the refinement's signature. The reason for the check + // is that if the refinement does not refer to a member symbol, we will have to + // resort to reflection to invoke the member. And Java reflection needs to know exact + // erased parameter types. See neg/i12211.scala. Other reflection algorithms could + // conceivably dispatch without knowning precise parameter signatures. One can signal + // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, + // in which case the signature test is elided. + def sigsOK(symInfo: Type, info2: Type) = + tp2.underlyingClassRef(refinementOK = true).member(name).exists + || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) + || symInfo.isInstanceOf[MethodType] + && symInfo.signature.consistentParams(info2.signature) + + // A relaxed version of isSubType, which compares method types + // under the standard arrow rule which is contravarient in the parameter types, + // but under the condition that signatures might have to match (see sigsOK) + // This relaxed version is needed to correctly compare dependent function types. + // See pos/i12211.scala. + def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = + info2 match + case info2: MethodType => + info1 match + case info1: MethodType => + val symInfo1 = symInfo.stripPoly + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) + && sigsOK(symInfo1, info2) + case _ => isSubType(info1, info2) + case _ => isSubType(info1, info2) + def qualifies(m: SingleDenotation): Boolean = val info1 = m.info.widenExpr isSubInfo(info1, tp2.refinedInfo.widenExpr, m.symbol.info.orElse(info1)) || matchAbstractTypeMember(m.info) - end qualifies tp1.member(name) match // inlined hasAltWith for performance case mbr: SingleDenotation => qualifies(mbr) @@ -1956,8 +1998,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case formal2 :: rest2 => val formal2a = if (tp2.isParamDependent) formal2.subst(tp2, tp1) else formal2 val paramsMatch = - if precise then isSameTypeWhenFrozen(formal1, formal2a) - else isSubTypeWhenFrozen(formal2a, formal1) + if precise then + isSameTypeWhenFrozen(formal1, formal2a) + else if ctx.phase == Phases.checkCapturesPhase then + isSubType(formal2a, formal1) + else + isSubTypeWhenFrozen(formal2a, formal1) paramsMatch && loop(rest1, rest2) case nil => false @@ -2360,6 +2406,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } case tp1: TypeVar if tp1.isInstantiated => tp1.underlying & tp2 + case CapturingType(parent1, refs1, _) => + if subCaptures(tp2.captureSet, refs1, frozenConstraint).isOK then + parent1 & tp2 + else + tp1.derivedCapturingType(parent1 & tp2, refs1) case tp1: AnnotatedType if !tp1.isRefining => tp1.underlying & tp2 case _ => @@ -2422,6 +2473,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling false } + protected def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + refs1.subCaptures(refs2, frozen) + // ----------- Diagnostics -------------------------------------------------- /** A hook for showing subtype traces. Overridden in ExplainingTypeComparer */ @@ -2687,6 +2741,7 @@ object TypeComparer { else res match case ClassInfo(_, cls, _, _, _) => cls.showLocated case bounds: TypeBounds => i"type bounds [$bounds]" + case CaptureSet.CompareResult.OK => "OK" case res: printing.Showable => res.show case _ => String.valueOf(res) @@ -3015,5 +3070,10 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.addConstraint(param, bound, fromBelow) } + override def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + traceIndented(i"subcaptures $refs1 <:< $refs2 ${if frozen then "frozen" else ""}") { + super.subCaptures(refs1, refs2, frozen) + } + def lastTrace(header: String): String = header + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala index c9ca98f65f5e..9067d0c87142 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala @@ -73,6 +73,7 @@ class RecursionOverflow(val op: String, details: => String, val previous: Throwa s"""Recursion limit exceeded. |Maybe there is an illegal cyclic reference? |If that's not the case, you could also try to increase the stacksize using the -Xss JVM option. + |For the unprocessed stack trace, compile with -Yno-decode-stacktraces. |A recurring operation is (inner to outer): |${opsString(mostCommon)}""".stripMargin } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 6a5145ffd202..dfdcb5d38054 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -19,6 +19,8 @@ import typer.ForceDegree import typer.Inferencing._ import typer.IfBottom import reporting.TestingReporter +import cc.{CapturingType, derivedCapturingType, CaptureSet} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -164,6 +166,12 @@ object TypeOps: // with Nulls (which have no base classes). Under -Yexplicit-nulls, we take // corrective steps, so no widening is wanted. simplify(l, theMap) | simplify(r, theMap) + case CapturingType(parent, refs, _) => + if !ctx.mode.is(Mode.Type) + && refs.subCaptures(parent.captureSet, frozen = true).isOK then + simplify(parent, theMap) + else + mapOver case tp @ AnnotatedType(parent, annot) => val parent1 = simplify(parent, theMap) if annot.symbol == defn.UncheckedVarianceAnnot @@ -273,15 +281,23 @@ object TypeOps: case _ => false } - // Step 1: Get RecTypes and ErrorTypes out of the way, + // Step 1: Get RecTypes and ErrorTypes and CapturingTypes out of the way, tp1 match { - case tp1: RecType => return tp1.rebind(approximateOr(tp1.parent, tp2)) - case err: ErrorType => return err + case tp1: RecType => + return tp1.rebind(approximateOr(tp1.parent, tp2)) + case CapturingType(parent1, refs1, _) => + return tp1.derivedCapturingType(approximateOr(parent1, tp2), refs1) + case err: ErrorType => + return err case _ => } tp2 match { - case tp2: RecType => return tp2.rebind(approximateOr(tp1, tp2.parent)) - case err: ErrorType => return err + case tp2: RecType => + return tp2.rebind(approximateOr(tp1, tp2.parent)) + case CapturingType(parent2, refs2, _) => + return tp2.derivedCapturingType(approximateOr(tp1, parent2), refs2) + case err: ErrorType => + return err case _ => } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index d3c0eeab73d9..84dec9e43784 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,6 +38,8 @@ import scala.util.hashing.{ MurmurHash3 => hashing } import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference +import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -67,7 +69,7 @@ object Types { * | | +--- SkolemType * | +- TypeParamRef * | +- RefinedOrRecType -+-- RefinedType - * | | -+-- RecType + * | | +-- RecType * | +- AppliedType * | +- TypeBounds * | +- ExprType @@ -187,7 +189,7 @@ object Types { * It makes no sense for it to be an alias type because isRef would always * return false in that case. */ - def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = stripped match { + def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = this match { case this1: TypeRef => this1.info match { // see comment in Namer#typeDefSig case TypeAlias(tp) => tp.isRef(sym, skipRefined) @@ -199,6 +201,12 @@ object Types { val this2 = this1.dealias if (this2 ne this1) this2.isRef(sym, skipRefined) else this1.underlying.isRef(sym, skipRefined) + case this1: TypeVar => + this1.instanceOpt.isRef(sym, skipRefined) + case this1: AnnotatedType => + this1 match + case CapturingType(_, _, _) => false + case _ => this1.parent.isRef(sym, skipRefined) case _ => false } @@ -365,6 +373,7 @@ object Types { case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference) case WildcardType(optBounds) => optBounds.unusableForInference + case CapturingType(parent, refs, _) => parent.unusableForInference || refs.elems.exists(_.unusableForInference) case _: ErrorType => true case _ => false @@ -1174,9 +1183,13 @@ object Types { */ def stripAnnots(using Context): Type = this - /** Strip TypeVars and Annotation wrappers */ + /** Strip TypeVars and Annotation and CapturingType wrappers */ def stripped(using Context): Type = this + def strippedDealias(using Context): Type = + val tp1 = stripped.dealias + if tp1 ne this then tp1.strippedDealias else this + def rewrapAnnots(tp: Type)(using Context): Type = tp.stripTypeVar match { case AnnotatedType(tp1, annot) => AnnotatedType(rewrapAnnots(tp1), annot) case _ => this @@ -1367,8 +1380,13 @@ object Types { val tp1 = tp.instanceOpt if (tp1.exists) tp1.dealias1(keep) else tp case tp: AnnotatedType => - val tp1 = tp.parent.dealias1(keep) - if keep(tp) then tp.derivedAnnotatedType(tp1, tp.annot) else tp1 + val parent1 = tp.parent.dealias1(keep) + tp match + case tp @ CapturingType(parent, refs, _) => + tp.derivedCapturingType(parent1, refs) + case _ => + if keep(tp) then tp.derivedAnnotatedType(parent1, tp.annot) + else parent1 case tp: LazyRef => tp.ref.dealias1(keep) case _ => this @@ -1461,7 +1479,7 @@ object Types { if (tp.tycon.isLambdaSub) NoType else tp.superType.underlyingClassRef(refinementOK) case tp: AnnotatedType => - tp.underlying.underlyingClassRef(refinementOK) + tp.parent.underlyingClassRef(refinementOK) case tp: RefinedType => if (refinementOK) tp.underlying.underlyingClassRef(refinementOK) else NoType case tp: RecType => @@ -1504,6 +1522,8 @@ object Types { case _ => if (isRepeatedParam) this.argTypesHi.head else this } + def captureSet(using Context): CaptureSet = CaptureSet.ofType(this) + // ----- Normalizing typerefs over refined types ---------------------------- /** If this normalizes* to a refinement type that has a refinement for `name` (which might be followed @@ -1791,7 +1811,7 @@ object Types { * @param dropLast The number of trailing parameters that should be dropped * when forming the function type. */ - def toFunctionType(isJava: Boolean, dropLast: Int = 0)(using Context): Type = this match { + def toFunctionType(isJava: Boolean, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match { case mt: MethodType if !mt.isParamDependent => val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast val isContextual = mt.isContextualMethod && !ctx.erasedTypes @@ -1803,7 +1823,7 @@ object Types { val funType = defn.FunctionOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), result1, isContextual, isErased) - if (mt.isResultDependent) RefinedType(funType, nme.apply, mt) + if alwaysDependent || mt.isResultDependent then RefinedType(funType, nme.apply, mt) else funType } @@ -1835,6 +1855,16 @@ object Types { case _ => this } + def capturing(ref: CaptureRef)(using Context): Type = + if captureSet.accountsFor(ref) then this + else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing) + + def capturing(cs: CaptureSet)(using Context): Type = + if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this + else this match + case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs) + case _ => CapturingType(this, cs, this.isBoxedCapturing) + /** The set of distinct symbols referred to by this type, after all aliases are expanded */ def coveringSet(using Context): Set[Symbol] = (new CoveringSetAccumulator).apply(Set.empty[Symbol], this) @@ -2015,6 +2045,40 @@ object Types { def isOverloaded(using Context): Boolean = false } + /** A trait for references in CaptureSets. These can be NamedTypes, ThisTypes or ParamRefs */ + trait CaptureRef extends SingletonType: + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + private var mySingletonCaptureSet: CaptureSet.Const = null + + def canBeTracked(using Context): Boolean + final def isTracked(using Context): Boolean = canBeTracked && !captureSetOfInfo.isAlwaysEmpty + def isRootCapability(using Context): Boolean = false + def normalizedRef(using Context): CaptureRef = this + + def singletonCaptureSet(using Context): CaptureSet.Const = + if mySingletonCaptureSet == null then + mySingletonCaptureSet = CaptureSet(this.normalizedRef) + mySingletonCaptureSet + + def captureSetOfInfo(using Context): CaptureSet = + if ctx.runId == myCaptureSetRunId then myCaptureSet + else if myCaptureSet eq CaptureSet.Pending then CaptureSet.empty + else + myCaptureSet = CaptureSet.Pending + val computed = CaptureSet.ofInfo(this) + if ctx.phase != Phases.checkCapturesPhase || underlying.isProvisional then + myCaptureSet = null + else + myCaptureSet = computed + myCaptureSetRunId = ctx.runId + computed + + override def captureSet(using Context): CaptureSet = + val cs = captureSetOfInfo + if canBeTracked && !cs.isAlwaysEmpty then singletonCaptureSet else cs + end CaptureRef + /** A trait for types that bind other types that refer to them. * Instances are: LambdaType, RecType. */ @@ -2062,7 +2126,7 @@ object Types { // --- NamedTypes ------------------------------------------------------------------ - abstract class NamedType extends CachedProxyType with ValueType { self => + abstract class NamedType extends CachedProxyType, ValueType { self => type ThisType >: this.type <: NamedType type ThisName <: Name @@ -2081,6 +2145,9 @@ object Types { private var mySignature: Signature = _ private var mySignatureRunId: Int = NoRunId + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + // Invariants: // (1) checkedPeriod != Nowhere => lastDenotation != null // (2) lastDenotation != null => lastSymbol != null @@ -2433,7 +2500,7 @@ object Types { val tparam = symbol val cls = tparam.owner val base = pre.baseType(cls) - base match { + base.stripped match { case AppliedType(_, allArgs) => var tparams = cls.typeParams var args = allArgs @@ -2613,7 +2680,7 @@ object Types { */ abstract case class TermRef(override val prefix: Type, private var myDesignator: Designator) - extends NamedType with SingletonType with ImplicitRef { + extends NamedType, ImplicitRef, CaptureRef { type ThisType = TermRef type ThisName = TermName @@ -2637,6 +2704,25 @@ object Types { def implicitName(using Context): TermName = name def underlyingRef: TermRef = this + + /** A term reference can be tracked if it is a local term ref to a value + * or a method term parameter. References to term parameters of classes + * cannot be tracked individually. + * They are subsumed in the capture sets of the enclosing class. + * TODO: ^^^ What avout call-by-name? + */ + def canBeTracked(using Context) = + ((prefix eq NoPrefix) + || symbol.is(ParamAccessor) && (prefix eq symbol.owner.thisType) + || symbol.hasAnnotation(defn.AbilityAnnot) + || isRootCapability + ) && !symbol.is(Method) + + override def isRootCapability(using Context): Boolean = + name == nme.CAPTURE_ROOT && symbol == defn.captureRoot + + override def normalizedRef(using Context): CaptureRef = + if canBeTracked then symbol.termRef else this } abstract case class TypeRef(override val prefix: Type, @@ -2772,7 +2858,7 @@ object Types { * Note: we do not pass a class symbol directly, because symbols * do not survive runs whereas typerefs do. */ - abstract case class ThisType(tref: TypeRef) extends CachedProxyType with SingletonType { + abstract case class ThisType(tref: TypeRef) extends CachedProxyType, CaptureRef { def cls(using Context): ClassSymbol = tref.stableInRunSymbol match { case cls: ClassSymbol => cls case _ if ctx.mode.is(Mode.Interactive) => defn.AnyClass // was observed to happen in IDE mode @@ -2786,6 +2872,8 @@ object Types { // can happen in IDE if `cls` is stale } + def canBeTracked(using Context) = true + override def computeHash(bs: Binders): Int = doHash(bs, tref) override def eql(that: Type): Boolean = that match { @@ -3612,9 +3700,17 @@ object Types { case tp: AppliedType => tp.fold(status, compute(_, _, theAcc)) case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps - case AnnotatedType(parent, ann) => - if ann.refersToParamOf(thisLambdaType) then TrueDeps - else compute(status, parent, theAcc) + case tp: AnnotatedType => + tp match + case CapturingType(parent, refs, _) => + (compute(status, parent, theAcc) /: refs.elems) { + (s, ref) => ref match + case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps) + case _ => s + } + case _ => + if tp.annot.refersToParamOf(thisLambdaType) then TrueDeps + else compute(status, tp.parent, theAcc) case _: ThisType | _: BoundType | NoPrefix => status case _ => (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) @@ -3653,29 +3749,52 @@ object Types { /** Does result type contain references to parameters of this method type, * which cannot be eliminated by de-aliasing? */ - def isResultDependent(using Context): Boolean = dependencyStatus == TrueDeps + def isResultDependent(using Context): Boolean = + dependencyStatus == TrueDeps || dependencyStatus == CaptureDeps /** Does one of the parameter types contain references to earlier parameters * of this method type which cannot be eliminated by de-aliasing? */ def isParamDependent(using Context): Boolean = paramDependencyStatus == TrueDeps + /** Is there either a true or false type dependency, or does the result + * type capture a parameter? + */ + def isCaptureDependent(using Context) = dependencyStatus == CaptureDeps + def newParamRef(n: Int): TermParamRef = new TermParamRefImpl(this, n) /** The least supertype of `resultType` that does not contain parameter dependencies */ def nonDependentResultApprox(using Context): Type = - if (isResultDependent) { + if isResultDependent then val dropDependencies = new ApproximatingTypeMap { def apply(tp: Type) = tp match { case tp @ TermParamRef(`thisLambdaType`, _) => range(defn.NothingType, atVariance(1)(apply(tp.underlying))) + case CapturingType(parent, refs, boxed) => + val parent1 = this(parent) + val elems1 = refs.elems.filter { + case tp @ TermParamRef(`thisLambdaType`, _) => false + case _ => true + } + if elems1.size == refs.elems.size then + derivedCapturingType(tp, parent1, refs) + else + range( + CapturingType(parent1, CaptureSet(elems1), boxed), + CapturingType(parent1, CaptureSet.universal, boxed)) case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) => - mapOver(parent) + val parent1 = mapOver(parent) + if ann.symbol == defn.RetainsAnnot then + range( + AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation), + AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation)) + else + parent1 case _ => mapOver(tp) } } dropDependencies(resultType) - } else resultType } @@ -4046,9 +4165,10 @@ object Types { final val Unknown: DependencyStatus = 0 // not yet computed final val NoDeps: DependencyStatus = 1 // no dependent parameters found final val FalseDeps: DependencyStatus = 2 // all dependent parameters are prefixes of non-depended alias types - final val TrueDeps: DependencyStatus = 3 // some truly dependent parameters exist - final val StatusMask: DependencyStatus = 3 // the bits indicating actual dependency status - final val Provisional: DependencyStatus = 4 // set if dependency status can still change due to type variable instantiations + final val CaptureDeps: DependencyStatus = 3 + final val TrueDeps: DependencyStatus = 4 // some truly dependent parameters exist + final val StatusMask: DependencyStatus = 7 // the bits indicating actual dependency status + final val Provisional: DependencyStatus = 8 // set if dependency status can still change due to type variable instantiations } // ----- Type application: LambdaParam, AppliedType --------------------- @@ -4370,8 +4490,9 @@ object Types { /** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to * refer to `TermParamRef(binder, paramNum)`. */ - abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef with SingletonType { + abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef, CaptureRef { type BT = TermLambda + def canBeTracked(using Context) = true def kindString: String = "Term" def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum) } @@ -5040,7 +5161,7 @@ object Types { // ----- Annotated and Import types ----------------------------------------------- /** An annotated type tpe @ annot */ - abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType with ValueType { + abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType, ValueType { override def underlying(using Context): Type = parent @@ -5069,16 +5190,16 @@ object Types { // equals comes from case class; no matching override is needed override def computeHash(bs: Binders): Int = - doHash(bs, System.identityHashCode(annot), parent) + doHash(bs, annot.hash, parent) override def hashIsStable: Boolean = parent.hashIsStable override def eql(that: Type): Boolean = that match - case that: AnnotatedType => (parent eq that.parent) && (annot eq that.annot) + case that: AnnotatedType => (parent eq that.parent) && (annot eql that.annot) case _ => false override def iso(that: Any, bs: BinderPairs): Boolean = that match - case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eq that.annot) + case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eql that.annot) case _ => false } @@ -5089,6 +5210,7 @@ object Types { annots.foldLeft(underlying)(apply(_, _)) def apply(parent: Type, annot: Annotation)(using Context): AnnotatedType = unique(CachedAnnotatedType(parent, annot)) + end AnnotatedType // Special type objects and classes ----------------------------------------------------- @@ -5308,7 +5430,7 @@ object Types { /** Common base class of TypeMap and TypeAccumulator */ abstract class VariantTraversal: - protected[core] var variance: Int = 1 + protected[dotc] var variance: Int = 1 inline protected def atVariance[T](v: Int)(op: => T): T = { val saved = variance @@ -5334,6 +5456,24 @@ object Types { } end VariantTraversal + /** A supertrait for some typemaps that are bijections. Used for capture checking + * BiTypeMaps should map capture references to capture references. + */ + trait BiTypeMap extends TypeMap: + thisMap => + def inverse(tp: Type): Type + + def inverseTypeMap(using Context) = new BiTypeMap: + def apply(tp: Type) = thisMap.inverse(tp) + def inverse(tp: Type) = thisMap.apply(tp) + + def forward(ref: CaptureRef): CaptureRef = this(ref) match + case result: CaptureRef if result.canBeTracked => result + + def backward(ref: CaptureRef): CaptureRef = inverse(ref) match + case result: CaptureRef if result.canBeTracked => result + end BiTypeMap + abstract class TypeMap(implicit protected var mapCtx: Context) extends VariantTraversal with (Type => Type) { thisMap => @@ -5361,6 +5501,8 @@ object Types { tp.derivedMatchType(bound, scrutinee, cases) protected def derivedAnnotatedType(tp: AnnotatedType, underlying: Type, annot: Annotation): Type = tp.derivedAnnotatedType(underlying, annot) + protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + tp.derivedCapturingType(parent, refs) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -5396,6 +5538,12 @@ object Types { def isRange(tp: Type): Boolean = tp.isInstanceOf[Range] + protected def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + val saved = variance + variance = v + try derivedCapturingType(tp, this(parent), refs.map(this)) + finally variance = saved + /** Map this function over given type */ def mapOver(tp: Type): Type = { record(s"TypeMap mapOver ${getClass}") @@ -5437,6 +5585,9 @@ object Types { case tp: ExprType => derivedExprType(tp, this(tp.resultType)) + case CapturingType(parent, refs, _) => + mapCapturingType(tp, parent, refs, variance) + case tp @ AnnotatedType(underlying, annot) => val underlying1 = this(underlying) val annot1 = annot.mapWith(this) @@ -5757,6 +5908,13 @@ object Types { if (underlying.isExactlyNothing) underlying else tp.derivedAnnotatedType(underlying, annot) } + override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + parent match // ^^^ handle ranges in capture sets as well + case Range(lo, hi) => + range(derivedCapturingType(tp, lo, refs), derivedCapturingType(tp, hi, refs)) + case _ => + tp.derivedCapturingType(parent, refs) + override protected def derivedWildcardType(tp: WildcardType, bounds: Type): WildcardType = tp.derivedWildcardType(rangeToBounds(bounds)) @@ -5796,6 +5954,12 @@ object Types { tp.derivedLambdaType(tp.paramNames, formals, restpe) } + override def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + if v == 0 then + range(mapCapturingType(tp, parent, refs, -1), mapCapturingType(tp, parent, refs, 1)) + else + super.mapCapturingType(tp, parent, refs, v) + protected def reapply(tp: Type): Type = apply(tp) } @@ -5893,6 +6057,9 @@ object Types { val x2 = atVariance(0)(this(x1, tp.scrutinee)) foldOver(x2, tp.cases) + case CapturingType(parent, refs, _) => + (this(x, parent) /: refs.elems)(this) + case AnnotatedType(underlying, annot) => this(applyToAnnot(x, annot), underlying) diff --git a/compiler/src/dotty/tools/dotc/core/Variances.scala b/compiler/src/dotty/tools/dotc/core/Variances.scala index 122c7a10e4b7..44dda6b0077e 100644 --- a/compiler/src/dotty/tools/dotc/core/Variances.scala +++ b/compiler/src/dotty/tools/dotc/core/Variances.scala @@ -4,6 +4,7 @@ package core import Types._, Contexts._, Flags._, Symbols._, Annotations._ import TypeApplications.TypeParamInfo import Decorators._ +import cc.CapturingType object Variances { @@ -99,6 +100,8 @@ object Variances { v } varianceInArgs(varianceInType(tycon)(tparam), args, tycon.typeParams) + case CapturingType(tp, _, _) => + varianceInType(tp)(tparam) case AnnotatedType(tp, annot) => varianceInType(tp)(tparam) & varianceInAnnot(annot)(tparam) case AndType(tp1, tp2) => diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index e4a5c0ae8c6d..bbc18cd3f5ff 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -821,7 +821,7 @@ class TreeUnpickler(reader: TastyReader, def TypeDef(rhs: Tree) = ta.assignType(untpd.TypeDef(sym.name.asTypeName, rhs), sym) - def ta = ctx.typeAssigner + def ta = ctx.typeAssigner val name = readName() pickling.println(s"reading def of $name at $start") @@ -1263,11 +1263,9 @@ class TreeUnpickler(reader: TastyReader, // types. This came up in #137 of collection strawman. val tycon = readTpt() val args = until(end)(readTpt()) - val ownType = - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.safeAppliedTo(args.tpes) - untpd.AppliedTypeTree(tycon, args).withType(ownType) + val tree = untpd.AppliedTypeTree(tycon, args) + val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes)) + tree.withType(ownType) case ANNOTATEDtpt => Annotated(readTpt(), readTerm()) case LAMBDAtpt => diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 27c1dead4482..6ee299d34922 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -890,6 +890,24 @@ object Parsers { } } + def followingIsCaptureSet(): Boolean = + val lookahead = in.LookaheadScanner() + def recur(): Boolean = + (lookahead.isIdent || lookahead.token == THIS) && { + lookahead.nextToken() + if lookahead.token == COMMA then + lookahead.nextToken() + recur() + else + lookahead.token == RBRACE && { + lookahead.nextToken() + canStartInfixTypeTokens.contains(lookahead.token) + || lookahead.token == LBRACKET + } + } + lookahead.nextToken() + recur() + /* --------- OPERAND/OPERATOR STACK --------------------------------------- */ var opStack: List[OpInfo] = Nil @@ -1330,17 +1348,25 @@ object Parsers { case _ => false } + /** CaptureRef ::= ident | `this` + */ + def captureRef(): Tree = + if in.token == THIS then simpleRef() else termIdent() + /** Type ::= FunType * | HkTypeParamClause ‘=>>’ Type * | FunParamClause ‘=>>’ Type * | MatchType * | InfixType + * | CaptureSet Type * FunType ::= (MonoFunType | PolyFunType) * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type * PolyFunType ::= HKTypeParamClause '=>' Type * FunTypeArgs ::= InfixType * | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)' * | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')' + * CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` + * CaptureRef ::= Ident */ def typ(): Tree = { val start = in.offset @@ -1446,6 +1472,10 @@ object Parsers { } else { accept(TLARROW); typ() } } + else if in.token == LBRACE && followingIsCaptureSet() then + val refs = inBraces { commaSeparated(captureRef) } + val t = typ() + CapturingTypeTree(refs, t) else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() @@ -1514,7 +1544,7 @@ object Parsers { def infixType(): Tree = infixTypeRest(refinedType()) def infixTypeRest(t: Tree): Tree = - infixOps(t, canStartTypeTokens, refinedTypeFn, Location.ElseWhere, + infixOps(t, canStartInfixTypeTokens, refinedTypeFn, Location.ElseWhere, isType = true, isOperator = !followingIsVararg()) @@ -3154,7 +3184,7 @@ object Parsers { ImportSelector( atSpan(in.skipToken()) { Ident(nme.EMPTY) }, bound = - if canStartTypeTokens.contains(in.token) then rejectWildcardType(infixType()) + if canStartInfixTypeTokens.contains(in.token) then rejectWildcardType(infixType()) else EmptyTree) /** id [‘as’ (id | ‘_’) */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala index cba07a6e5a34..7fadf341905d 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala @@ -230,8 +230,8 @@ object Tokens extends TokensCommon { final val canStartExprTokens2: TokenSet = canStartExprTokens3 | BitSet(DO) - final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( - THIS, SUPER, USCORE, LPAREN, AT) + final val canStartInfixTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( + THIS, SUPER, USCORE, LPAREN, LBRACE, AT) final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 3a3fca5e7f90..054fe62682dd 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -14,13 +14,16 @@ import Variances.varianceSign import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch +import config.Config +import cc.{CapturingType, CaptureSet} class PlainPrinter(_ctx: Context) extends Printer { + /** The context of all public methods in Printer and subclasses. * Overridden in RefinedPrinter. */ - protected def curCtx: Context = _ctx.addMode(Mode.Printing) - protected given [DummyToEnforceDef]: Context = curCtx + def printerContext: Context = _ctx.addMode(Mode.Printing) + protected given [DummyToEnforceDef]: Context = printerContext protected def printDebug = ctx.settings.YprintDebug.value @@ -186,6 +189,22 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close + case CapturingType(parent, refs, boxed) => + def box = Str("box ") provided boxed + if printDebug && !refs.isConst then + changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent)) + else if ctx.settings.YccDebug.value then + changePrec(GlobalPrec)(box ~ refs.toText(this) ~ " " ~ toText(parent)) + else if !refs.isConst && refs.elems.isEmpty then + changePrec(GlobalPrec)("?" ~ " " ~ toText(parent)) + else if Config.printCaptureSetsAsPrefix then + changePrec(GlobalPrec)( + box ~ "{" + ~ Text(refs.elems.toList.map(toTextCaptureRef), ", ") + ~ "} " + ~ toText(parent)) + else + changePrec(InfixPrec)(toText(parent) ~ " retains " ~ box ~ toText(refs.toRetainsTypeArg)) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => "" // do not print previously reported error message because they may try to print this error type again recuresevely case tp: ErrorType => @@ -273,7 +292,7 @@ class PlainPrinter(_ctx: Context) extends Printer { /** If -uniqid is set, the unique id of symbol, after a # */ protected def idString(sym: Symbol): String = - if (showUniqueIds || Printer.debugPrintUnique) "#" + sym.id else "" + if showUniqueIds then "#" + sym.id else "" def nameString(sym: Symbol): String = simpleNameString(sym) + idString(sym) // + "<" + (if (sym.exists) sym.owner else "") + ">" @@ -313,7 +332,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case tp @ ConstantType(value) => toText(value) case pref: TermParamRef => - nameString(pref.binder.paramNames(pref.paramNum)) + nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder) case tp: RecThis => val idx = openRecs.reverse.indexOf(tp.binder) if (idx >= 0) selfRecName(idx + 1) @@ -334,6 +353,11 @@ class PlainPrinter(_ctx: Context) extends Printer { } } + def toTextCaptureRef(tp: Type): Text = + homogenize(tp) match + case tp: SingletonType => toTextRef(tp) + case _ => toText(tp) + protected def isOmittablePrefix(sym: Symbol): Boolean = defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym) diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index 550bdb94af4f..b883b6be805b 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -6,7 +6,7 @@ import core._ import Texts._, ast.Trees._ import Types.{Type, SingletonType, LambdaParam}, Symbols.Symbol, Scopes.Scope, Constants.Constant, - Names.Name, Denotations._, Annotations.Annotation + Names.Name, Denotations._, Annotations.Annotation, Contexts.Context import typer.Implicits.SearchResult import util.SourcePosition import typer.ImportInfo @@ -104,6 +104,9 @@ abstract class Printer { /** Textual representation of a prefix of some reference, ending in `.` or `#` */ def toTextPrefix(tp: Type): Text + /** Textual representation of a reference in a capture set */ + def toTextCaptureRef(tp: Type): Text + /** Textual representation of symbol's declaration */ def dclText(sym: Symbol): Text @@ -182,6 +185,9 @@ abstract class Printer { /** A plain printer without any embellishments */ def plain: Printer + + /** The context in which this printer operates */ + def printerContext: Context } object Printer { diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index ef11ec9434ae..c7606acf3dd6 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -34,11 +34,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { /** A stack of enclosing DefDef, TypeDef, or ClassDef, or ModuleDefs nodes */ private var enclosingDef: untpd.Tree = untpd.EmptyTree - private var myCtx: Context = super.curCtx + private var myCtx: Context = super.printerContext private var printPos = ctx.settings.YprintPos.value private val printLines = ctx.settings.printLines.value - override protected def curCtx: Context = myCtx + override def printerContext: Context = myCtx def withEnclosingDef(enclDef: Tree[? >: Untyped])(op: => Text): Text = { val savedCtx = myCtx @@ -164,10 +164,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(GlobalPrec) { "(" ~ keywordText("erased ").provided(info.isErasedMethod) - ~ ( if info.isParamDependent || info.isResultDependent - then paramsText(info) - else argsText(info.paramInfos) - ) + ~ paramsText(info) ~ ") " ~ arrow(info.isImplicitMethod) ~ " " @@ -245,9 +242,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty => // don't eta contract if the application would be printed specially toText(tycon) - case tp: RefinedType - if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass)) - && !printDebug => + case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug => toTextMethodAsFunction(tp.refinedInfo) case tp: TypeRef => if (tp.symbol.isAnonymousClass && !showUniqueIds) @@ -703,6 +698,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val (prefix, postfix) = if isTermHole then ("{{{ ", " }}}") else ("[[[ ", " ]]]") val argsText = toTextGlobal(args, ", ") prefix ~~ idx.toString ~~ "|" ~~ argsText ~~ postfix + case CapturingTypeTree(refs, parent) => + changePrec(GlobalPrec)("{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent)) case _ => tree.fallbackToText(this) } @@ -789,9 +786,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if mdef.hasType then Modifiers(mdef.symbol) else mdef.rawMods private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers( - sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), + sym.flagsUNSAFE & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY, - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot @@ -988,13 +985,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else if (suppressKw) PrintableFlags(isType) &~ Private else PrintableFlags(isType) if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= GivenOrImplicit // drop implicit/given from classes - val rawFlags = if (sym.exists) sym.flags else mods.flags + val rawFlags = if (sym.exists) sym.flagsUNSAFE else mods.flags if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased val flags = rawFlags & flagMask var flagsText = toTextFlags(sym, flags) val annotTexts = if sym.exists then - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _)) Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw) diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 063ba96410c8..eac14bda8d4b 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -286,7 +286,6 @@ import transform.SymUtils._ val treeStr = inTree.map(x => s"\nTree: ${x.show}").getOrElse("") treeStr + "\n" + super.explain - end TypeMismatch class NotAMember(site: Type, val name: Name, selected: String, addendum: => String = "")(using Context) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index 6e198bbeada9..595094f3edd6 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -175,6 +175,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { private val byNameMarker = marker("ByName") private val matchMarker = marker("Match") private val superMarker = marker("Super") + private val retainsMarker = marker("Retains") /** Extract the API representation of a source file */ def apiSource(tree: Tree): Seq[api.ClassLike] = { diff --git a/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala new file mode 100644 index 000000000000..9a287b2dd1d9 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala @@ -0,0 +1,19 @@ +package dotty.tools.dotc +package transform + +import core.* +import Contexts.Context +import Phases.Phase + +/** A phase that can be inserted directly after a phase that cannot + * be checked, to enable a -Ycheck as soon as possible afterwards + */ +class EmptyPhase extends Phase: + + def phaseName: String = "dummy" + + override def isEnabled(using Context) = prev.isEnabled + + override def run(using Context) = () + +end EmptyPhase \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index c7e02a5c6837..048395e8dffa 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -289,7 +289,11 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase tree.fun, tree.args.mapConserve(arg => if (methType.isImplicitMethod && arg.span.isSynthetic) - PruneErasedDefs.trivialErasedTree(arg) + arg match + case _: RefTree | _: Apply | _: TypeApply if arg.symbol.is(Erased) => + dropInlines.transform(arg) + case _ => + PruneErasedDefs.trivialErasedTree(arg) else dropInlines.transform(arg))) else tree diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 76f89cb65757..a61b736a9cc1 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -14,15 +14,22 @@ import typer.ErrorReporting.err import typer.ProtoTypes.* import typer.TypeAssigner.seqLitType import typer.ConstFold +import NamerOps.methodType import config.Printers.recheckr import util.Property import StdNames.nme import reporting.trace +object Recheck: + + /** Attachment key for rechecked types of TypeTrees */ + private val RecheckedType = Property.Key[Type] + abstract class Recheck extends Phase, IdentityDenotTransformer: thisPhase => import ast.tpd.* + import Recheck.* def preRecheckPhase = this.prev.asInstanceOf[PreRecheck] @@ -36,12 +43,17 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: override def widenSkolems = true def run(using Context): Unit = - newRechecker().checkUnit(ctx.compilationUnit) + val rechecker = newRechecker() + rechecker.transformTypes.traverse(ctx.compilationUnit.tpdTree) + rechecker.checkUnit(ctx.compilationUnit) def newRechecker()(using Context): Rechecker class Rechecker(ictx: Context): - val ta = ictx.typeAssigner + private val ta = ictx.typeAssigner + private val keepTypes = inContext(ictx) { + ictx.settings.Xprint.value.containsPhase(thisPhase) + } extension (sym: Symbol) def updateInfo(newInfo: Type)(using Context): Unit = if sym.info ne newInfo then @@ -53,23 +65,102 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: else sym.flags ).installAfter(preRecheckPhase) - /** Hook to be overridden */ - protected def reinfer(tp: Type)(using Context): Type = tp - - def reinferResult(info: Type)(using Context): Type = info match - case info: MethodOrPoly => - info.derivedLambdaType(resType = reinferResult(info.resultType)) - case _ => - reinfer(info) + extension (tpe: Type) def rememberFor(tree: Tree)(using Context): Unit = + if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then + tree.putAttachment(RecheckedType, tpe) + + def knownType(tree: Tree) = + tree.attachmentOrElse(RecheckedType, tree.tpe) + + def isUpdated(sym: Symbol)(using Context) = + val symd = sym.denot + symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd) + + def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp + + object transformTypes extends TreeTraverser: + + // Substitute parameter symbols in `from` to paramRefs in corresponding + // method or poly types `to`. We use a single BiTypeMap to do everything. + class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) + extends DeepTypeMap, BiTypeMap: + + def apply(t: Type): Type = t match + case t: NamedType => + val sym = t.symbol + def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type = + def inner(from: List[Symbol], to: List[ParamRef]): Type = + if from.isEmpty then outer(froms.tail, tos.tail) + else if sym eq from.head then to.head + else inner(from.tail, to.tail) + if tos.isEmpty then t + else inner(froms.head, tos.head.paramRefs) + outer(from, to) + case _ => + mapOver(t) + + def inverse(t: Type): Type = t match + case t: ParamRef => + def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = + if from.isEmpty then t + else if t.binder eq from.head then to.head(t.paramNum).namedType + else recur(from.tail, to.tail) + recur(to, from) + case _ => + mapOver(t) + end SubstParams + + def traverse(tree: Tree)(using Context) = + traverseChildren(tree) + tree match - def enterDef(stat: Tree)(using Context): Unit = - val sym = stat.symbol - stat match - case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] => - sym.updateInfo(reinferResult(sym.info)) - case stat: Bind => - sym.updateInfo(reinferResult(sym.info)) - case _ => + case tree: TypeTree => + transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree]).rememberFor(tree) + case tree: ValOrDefDef => + val sym = tree.symbol + + // replace an existing symbol info with inferred types + def integrateRT( + info: Type, // symbol info to replace + psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info` + prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order + prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order + ): Type = + info match + case mt: MethodOrPoly => + val psyms = psymss.head + mt.companion(mt.paramNames)( + mt1 => + if !psyms.exists(isUpdated) && !mt.isParamDependent && prevLambdas.isEmpty then + mt.paramInfos + else + val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) + psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), + mt1 => + integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) + ) + case info: ExprType => + info.derivedExprType(resType = + integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) + case _ => + val restp = knownType(tree.tpt) + if prevLambdas.isEmpty then restp + else SubstParams(prevPsymss, prevLambdas)(restp) + + if tree.tpt.hasAttachment(RecheckedType) && !sym.isConstructor then + val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) + .showing(i"update info $sym: ${sym.info} --> $result", recheckr) + if newInfo ne sym.info then + val completer = new LazyType: + def complete(denot: SymDenotation)(using Context) = + denot.info = newInfo + recheckDef(tree, sym) + sym.updateInfo(completer) + case tree: Bind => + val sym = tree.symbol + sym.updateInfo(transformType(sym.info, inferred = true)) + case _ => + end transformTypes def constFold(tree: Tree, tp: Type)(using Context): Type = val tree1 = tree.withType(tp) @@ -90,10 +181,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: excluded = if tree.symbol.is(Private) then EmptyFlags else Private ).suchThat(tree.symbol ==) constFold(tree, qualType.select(name, mbr)) + //.showing(i"recheck select $qualType . $name : ${mbr.symbol.info} = $result") def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match case Bind(name, body) => - enterDef(tree) recheck(body, pt) val sym = tree.symbol if sym.isType then sym.typeRef else sym.info @@ -104,16 +195,13 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val exprType = recheck(expr, defn.UnitType) bindType - def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type = - if !tree.rhs.isEmpty then recheck(tree.rhs, tree.symbol.info) - sym.termRef + def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) - def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type = - tree.paramss.foreach(_.foreach(enterDef)) - val rhsCtx = linkConstructorParams(sym) + def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - recheck(tree.rhs, tree.symbol.localReturnType)(using rhsCtx) - sym.termRef + inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) @@ -134,6 +222,11 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case _ => mapOver(t) formals.mapConserve(tm) + /** Hook for method type instantiation + */ + protected def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + mt.instantiate(argTypes) + def recheckApply(tree: Apply, pt: Type)(using Context): Type = recheck(tree.fun).widen match case fntpe: MethodType => @@ -153,7 +246,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: assert(formals.isEmpty) Nil val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs) - constFold(tree, fntpe.instantiate(argTypes)) + constFold(tree, instantiate(fntpe, argTypes, tree.fun.symbol)) def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type = recheck(tree.fun).widen match @@ -174,7 +267,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type = recheckStats(stats) - val exprType = recheck(expr, pt.dropIfProto) + val exprType = recheck(expr) + // The expected type `pt` is not propagated. Doing so would allow variables in the + // expected type to contain references to local symbols of the block, so the + // local symbols could escape that way. TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) def recheckBlock(tree: Block, pt: Type)(using Context): Type = @@ -195,10 +291,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckMatch(tree: Match, pt: Type)(using Context): Type = val selectorType = recheck(tree.selector) - val casesTypes = tree.cases.map(recheck(_, selectorType.widen, pt)) + val casesTypes = tree.cases.map(recheckCase(_, selectorType.widen, pt)) TypeComparer.lub(casesTypes) - def recheck(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = + def recheckCase(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = recheck(tree.pat, selType) recheck(tree.guard, defn.BooleanType) recheck(tree.body, pt) @@ -214,7 +310,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckTry(tree: Try, pt: Type)(using Context): Type = val bodyType = recheck(tree.expr, pt) - val casesTypes = tree.cases.map(recheck(_, defn.ThrowableType, pt)) + val casesTypes = tree.cases.map(recheckCase(_, defn.ThrowableType, pt)) val finalizerType = recheck(tree.finalizer, defn.UnitType) TypeComparer.lub(bodyType :: casesTypes) @@ -227,9 +323,8 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val elemTypes = tree.elems.map(recheck(_, elemProto)) seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes)) - def recheckTypeTree(tree: TypeTree)(using Context): Type = tree match - case tree: InferredTypeTree => reinfer(tree.tpe) - case _ => tree.tpe + def recheckTypeTree(tree: TypeTree)(using Context): Type = + knownType(tree) def recheckAnnotated(tree: Annotated)(using Context): Type = tree.tpe match @@ -246,14 +341,20 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: NoType def recheckStats(stats: List[Tree])(using Context): Unit = - stats.foreach(enterDef) stats.foreach(recheck(_)) + def recheckDef(tree: ValOrDefDef, sym: Symbol)(using Context): Unit = + inContext(ctx.localContext(tree, sym)) { + tree match + case tree: ValDef => recheckValDef(tree, sym) + case tree: DefDef => recheckDefDef(tree, sym) + } + /** Recheck tree without adapting it, returning its new type. * @param tree the original tree * @param pt the expected result type */ - def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + def recheckStart(tree: Tree, pt: Type = WildcardType)(using Context): Type = def recheckNamed(tree: NameTree, pt: Type)(using Context): Type = val sym = tree.symbol @@ -261,11 +362,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: Ident => recheckIdent(tree) case tree: Select => recheckSelect(tree) case tree: Bind => recheckBind(tree, pt) - case tree: ValDef => + case tree: ValOrDefDef => if tree.isEmpty then NoType - else recheckValDef(tree, sym)(using ctx.localContext(tree, sym)) - case tree: DefDef => - recheckDefDef(tree, sym)(using ctx.localContext(tree, sym)) + else + if isUpdated(sym) then sym.ensureCompleted() + else recheckDef(tree, sym) + sym.termRef case tree: TypeDef => tree.rhs match case impl: Template => @@ -295,35 +397,61 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: PackageDef => recheckPackageDef(tree) case tree: Thicket => defn.NothingType - try - val result = tree match - case tree: NameTree => recheckNamed(tree, pt) - case tree => recheckUnnamed(tree, pt) - checkConforms(result, pt, tree) - result - catch case ex: Exception => - println(i"error while rechecking $tree") - throw ex - } - end recheck + tree match + case tree: NameTree => recheckNamed(tree, pt) + case tree => recheckUnnamed(tree, pt) + end recheckStart + + def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = + checkConforms(tpe, pt, tree) + if keepTypes then tpe.rememberFor(tree) + tpe + + def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + try recheckFinish(recheckStart(tree, pt), tree, pt) + catch case ex: Exception => + println(i"error while rechecking $tree") + throw ex + } + + private val debugSuccesses = false def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match - case _: DefTree | EmptyTree | _: TypeTree => + case _: DefTree | EmptyTree | _: TypeTree | _: Closure => + // Don't report closure nodes, since their span is a point; wait instead + // for enclosing block to preduce an error case _ => val actual = tpe.widenExpr val expected = pt.widenExpr + //println(i"check conforms $actual <:< $expected") val isCompatible = actual <:< expected || expected.isRepeatedParam && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) if !isCompatible then - println(i"err at ${ctx.phase}") - err.typeMismatch(tree.withType(tpe), pt) + err.typeMismatch(tree.withType(tpe), expected) + else if debugSuccesses then + tree match + case _: Ident => + println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") + case _ => def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) end Rechecker + + override def show(tree: untpd.Tree)(using Context): String = + val addRecheckedTypes = new TreeMap: + override def transform(tree: Tree)(using Context): Tree = + val tree1 = super.transform(tree) + tree.getAttachment(RecheckedType) match + case Some(tpe) => tree1.withType(tpe) + case None => tree1 + atPhase(thisPhase) { + super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree])) + } end Recheck class TestRecheck extends Recheck: diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 29fd1adb6688..044ea11eb27e 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -375,14 +375,14 @@ class TreeChecker extends Phase with SymTransformer { val tpe = tree.typeOpt // Polymorphic apply methods stay structural until Erasure - val isPolyFunctionApply = (tree.name eq nme.apply) && (tree.qualifier.typeOpt <:< defn.PolyFunctionType) + val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass) // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then val denot = tree.denot assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation") - assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol") + assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, qualifier type = ${tree.qualifier.typeOpt}") val sym = tree.symbol val symIsFixed = tpe match { diff --git a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala index 6be58352e6dc..26bea001d1eb 100644 --- a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala @@ -70,7 +70,7 @@ class TryCatchPatterns extends MiniPhase { case _ => isDefaultCase(cdef) } - private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripAnnots match { + private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripped match { case tp @ TypeRef(pre, _) => (pre == NoPrefix || pre.typeSymbol.isStatic) && // Does not require outer class check !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 8ffe2198c4d9..7c5d34126bd9 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -148,7 +148,7 @@ object TypeTestsCasts { } case AndType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) case OrType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) - case AnnotatedType(t, _) => recur(X, t) + case tp: AnnotatedType => recur(X, tp.parent) case _: RefinedType => false case _ => true }) @@ -217,7 +217,7 @@ object TypeTestsCasts { * can be true in some cases. Issues a warning or an error otherwise. */ def checkSensical(foundClasses: List[Symbol])(using Context): Boolean = - def exprType = i"type ${expr.tpe.widen.stripAnnots}" + def exprType = i"type ${expr.tpe.widen.stripped}" def check(foundCls: Symbol): Boolean = if (!isCheckable(foundCls)) true else if (!foundCls.derivesFrom(testCls)) { diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala new file mode 100644 index 000000000000..1415016fea26 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -0,0 +1,468 @@ +package dotty.tools +package dotc +package cc + +import core._ +import Phases.*, DenotTransformers.*, SymDenotations.* +import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* +import Types._ +import Symbols._ +import StdNames._ +import Decorators._ +import config.Printers.{capt, recheckr} +import ast.{tpd, untpd, Trees} +import NameKinds.{DocArtifactName, OuterSelectName, DefaultGetterName} +import Trees._ +import scala.util.control.NonFatal +import typer.ErrorReporting._ +import typer.RefChecks +import util.Spans.Span +import util.{SimpleIdentitySet, EqHashMap, SrcPos} +import util.Chars.* +import transform.* +import transform.SymUtils.* +import scala.collection.mutable +import reporting._ +import dotty.tools.backend.jvm.DottyBackendInterface.symExtensions +import CaptureSet.{CompareResult, withCaptureSetsExplained} + +object CheckCaptures: + import ast.tpd.* + + case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer: Env): + def isOpen = !captured.isAlwaysEmpty && !isBoxed + + final class SubstParamsMap(from: BindingType, to: List[Type])(using Context) + extends ApproximatingTypeMap: + def apply(tp: Type): Type = tp match + case tp: ParamRef => + if tp.binder == from then to(tp.paramNum) else tp + case tp: NamedType => + if tp.prefix `eq` NoPrefix then tp + else tp.derivedSelect(apply(tp.prefix)) + case _: ThisType => + tp + case _ => + mapOver(tp) + + /** Check that a @retains annotation only mentions references that can be tracked + * This check is performed at Typer. + */ + def checkWellformed(ann: Tree)(using Context): Unit = + for elem <- retainedElems(ann) do + elem.tpe match + case ref: CaptureRef => + if !ref.canBeTracked then + report.error(em"$elem cannot be tracked since it is not a parameter or a local variable", elem.srcPos) + case tpe => + report.error(em"$tpe is not a legal type for a capture set", elem.srcPos) + + /** If `tp` is a capturing type, check that all references it mentions have non-empty + * capture sets. + * This check is performed after capture sets are computed in phase cc. + */ + def checkWellformedPost(tp: Type, pos: SrcPos)(using Context): Unit = tp match + case CapturingType(parent, refs, _) => + for ref <- refs.elems do + if ref.captureSetOfInfo.elems.isEmpty then + report.error(em"$ref cannot be tracked since its capture set is empty", pos) + else if parent.captureSet.accountsFor(ref) then + report.warning(em"redundant capture: $parent already accounts for $ref", pos) + case _ => + + def checkWellformedPost(ann: Tree)(using Context): Unit = + /** The lists `elems(i) :: prev.reerse :: elems(0),...,elems(i-1),elems(i+1),elems(n)` + * where `n == elems.length-1`, i <- 0..n`. + */ + def choices(prev: List[Tree], elems: List[Tree]): List[List[Tree]] = elems match + case Nil => Nil + case elem :: elems => + List(elem :: (prev reverse_::: elems)) ++ choices(elem :: prev, elems) + for case first :: others <- choices(Nil, retainedElems(ann)) do + val firstRef = first.toCaptureRef + val remaining = CaptureSet(others.map(_.toCaptureRef)*) + if remaining.accountsFor(firstRef) then + report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos) + + private inline val disallowGlobal = true + +class CheckCaptures extends Recheck: + thisPhase => + + import ast.tpd.* + import CheckCaptures.* + + def phaseName: String = "cc" + override def isEnabled(using Context) = ctx.settings.Ycc.value + + def newRechecker()(using Context) = CaptureChecker(ctx) + + override def run(using Context): Unit = + checkOverrides.traverse(ctx.compilationUnit.tpdTree) + super.run + + def checkOverrides = new TreeTraverser: + def traverse(t: Tree)(using Context) = + t match + case t: Template => + // ^^^ TODO: Can we avoid doing overrides checks twice? + // We need to do them here since only at this phase CaptureTypes are relevant + // But maybe we can then elide the check during the RefChecks phase if -Ycc is set? + RefChecks.checkAllOverrides(ctx.owner.asClass) + case _ => + traverseChildren(t) + + class CaptureChecker(ictx: Context) extends Rechecker(ictx): + import ast.tpd.* + + override def transformType(tp: Type, inferred: Boolean)(using Context): Type = + + def addInnerVars(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + tp.derivedAppliedType(tycon, args.map(addVars(_, boxed = true))) + case tp @ RefinedType(core, rname, rinfo) => + val rinfo1 = addVars(rinfo) + if defn.isFunctionType(tp) then + rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + else + tp.derivedRefinedType(addInnerVars(core), rname, rinfo1) + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(addVars(_)), + resType = addVars(tp.resType)) + case tp: PolyType => + tp.derivedLambdaType( + resType = addVars(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addVars(tp.resType)) + case _ => + tp + + def addFunctionRefinements(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + if defn.isNonRefinedFunction(tp) then + MethodType.companion( + isContextual = defn.isContextFunctionClass(tycon.classSymbol), + isErased = defn.isErasedFunctionClass(tycon.classSymbol) + )(args.init, addFunctionRefinements(args.last)) + .toFunctionType(isJava = false, alwaysDependent = true) + .showing(i"add function refinement $tp --> $result", capt) + else + tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_))) + case tp @ RefinedType(core, rname, rinfo) if !defn.isFunctionType(tp) => + tp.derivedRefinedType( + addFunctionRefinements(core), rname, addFunctionRefinements(rinfo)) + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = addFunctionRefinements(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addFunctionRefinements(tp.resType)) + case _ => + tp + + /** Refine a possibly applied class type C where the class has tracked parameters + * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } + * where CV_1, ..., CV_n are fresh capture sets. + */ + def addCaptureRefinements(tp: Type): Type = tp.stripped match + case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass => + val cls = tp.typeSymbol.asClass + cls.paramGetters.foldLeft(tp) { (core, getter) => + if getter.termRef.isTracked then + val getterType = tp.memberInfo(getter).strippedDealias + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + .showing(i"add capture refinement $tp --> $result", capt) + else + core + } + case _ => + tp + + def addVars(tp: Type, boxed: Boolean = false): Type = + var tp1 = addInnerVars(tp) + val tp2 = addCaptureRefinements(tp1) + if tp1.canHaveInferredCapture + then CapturingType(tp2, CaptureSet.Var(), boxed) + else tp2 + + if inferred then + val cleanup = new TypeMap: + def apply(t: Type) = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case _ => + mapOver(t) + addVars(addFunctionRefinements(cleanup(tp))) + .showing(i"reinfer $tp --> $result", capt) + else + val addBoxes = new TypeTraverser: + def setBoxed(t: Type) = t match + case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => + annot.tree.setBoxedCapturing() + case _ => + + def traverse(t: Type) = + t match + case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => + args.foreach(setBoxed) + case TypeBounds(lo, hi) => + setBoxed(lo); setBoxed(hi) + case _ => + traverseChildren(t) + end addBoxes + + addBoxes.traverse(tp) + tp + end transformType + + private def interpolator(using Context) = new TypeTraverser: + override def traverse(t: Type) = + t match + case CapturingType(parent, refs: CaptureSet.Var, _) => + if variance < 0 then capt.println(i"solving $t") + refs.solve(variance) + traverse(parent) + case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionOrPolyType(t) => + traverse(rinfo) + case tp: TypeVar => + case tp: TypeRef => + traverse(tp.prefix) + case _ => + traverseChildren(t) + + private def interpolateVarsIn(tpt: Tree)(using Context): Unit = + if tpt.isInstanceOf[InferredTypeTree] then + interpolator.traverse(knownType(tpt)) + .showing(i"solved vars in ${knownType(tpt)}", capt) + + private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null) + + private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap() + def capturedVars(sym: Symbol)(using Context) = + myCapturedVars.getOrElseUpdate(sym, + if sym.ownersIterator.exists(_.isTerm) then CaptureSet.Var() + else CaptureSet.empty) + + def markFree(sym: Symbol, pos: SrcPos)(using Context): Unit = + if sym.exists then + val ref = sym.termRef + def recur(env: Env): Unit = + if env.isOpen && env.owner != sym.enclosure then + capt.println(i"Mark $sym with cs ${ref.captureSet} free in ${env.owner}") + checkElem(ref, env.captured, pos) + if env.owner.isConstructor then + if env.outer.owner != sym.enclosure then recur(env.outer.outer) + else recur(env.outer) + if ref.isTracked then recur(curEnv) + + def includeCallCaptures(sym: Symbol, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + var targetSet = capturedVars(sym) + if !targetSet.isAlwaysEmpty && sym.enclosure == ownEnclosure then + targetSet = targetSet.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def includeBoxedCaptures(tp: Type, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + val targetSet = tp.boxedCaptured.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def assertSub(cs1: CaptureSet, cs2: CaptureSet)(using Context) = + assert(cs1.subCaptures(cs2, frozen = false).isOK, i"$cs1 is not a subset of $cs2") + + def checkElem(elem: CaptureRef, cs: CaptureSet, pos: SrcPos)(using Context) = + val res = elem.singletonCaptureSet.subCaptures(cs, frozen = false) + if !res.isOK then + report.error(i"$elem cannot be referenced here; it is not included in allowed capture set ${res.blocking}", pos) + + def checkSubset(cs1: CaptureSet, cs2: CaptureSet, pos: SrcPos)(using Context) = + val res = cs1.subCaptures(cs2, frozen = false) + if !res.isOK then + report.error(i"references $cs1 are not all included in allowed capture set ${res.blocking}", pos) + + override def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + val cs = capturedVars(tree.meth.symbol) + recheckr.println(i"typing closure $tree with cvs $cs") + super.recheckClosure(tree, pt).capturing(cs) + .showing(i"rechecked $tree, $result", capt) + + override def recheckIdent(tree: Ident)(using Context): Type = + markFree(tree.symbol, tree.srcPos) + if tree.symbol.is(Method) then includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckIdent(tree) + + override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + try super.recheckValDef(tree, sym) + finally + if !sym.is(Param) then + // parameters with inferred types belong to anonymous methods. We need to wait + // for more info from the context, so we cannot interpolate. Note that we cannot + // expect to have all necessary info available at the point where the anonymous + // function is compiled since we do not propagate expected types into blocks. + interpolateVarsIn(tree.tpt) + + override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val saved = curEnv + val localSet = capturedVars(sym) + if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv) + try super.recheckDefDef(tree, sym) + finally + interpolateVarsIn(tree.tpt) + curEnv = saved + + override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = + for param <- cls.paramGetters do + if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then + report.error( + "Implementation restriction: Class parameter with non-empty capture set must be a `val`", + param.srcPos) + val saved = curEnv + val localSet = capturedVars(cls) + if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv) + try super.recheckClassDef(tree, impl, cls) + finally curEnv = saved + + /** First half: Refine the type of a constructor call `new C(t_1, ..., t_n)` + * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked + * parameters of C and T_1, ..., T_m are the types of the corresponding arguments. + * + * Second half: union of all capture sets of arguments to tracked parameters. + */ + private def addParamArgRefinements(core: Type, argTypes: List[Type], cls: ClassSymbol)(using Context): (Type, CaptureSet) = + cls.paramGetters.lazyZip(argTypes).foldLeft((core, CaptureSet.empty: CaptureSet)) { (acc, refine) => + val (core, allCaptures) = acc + val (getter, argType) = refine + if getter.termRef.isTracked then + (RefinedType(core, getter.name, argType), allCaptures ++ argType.captureSet) + else + (core, allCaptures) + } + + /** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`. + * This means: + * - Instantiate result type with actual arguments + * - If call is to a constructor: + * - remember types of arguments corresponding to tracked + * parameters in refinements. + * - add capture set of instantiated class to capture set of result type. + */ + override def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + val ownType = + if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType) + else mt.resType + if sym.isConstructor then + val cls = sym.owner.asClass + val (refined, cs) = addParamArgRefinements(ownType, argTypes, cls) + refined.capturing(cs ++ capturedVars(cls) ++ capturedVars(sym)) + .showing(i"constr type $mt with $argTypes%, % in $cls = $result", capt) + else ownType + + def recheckByNameArg(tree: Tree, pt: Type)(using Context): Type = + val closureDef(mdef) = tree + val arg = mdef.rhs + val localSet = CaptureSet.Var() + curEnv = Env(mdef.symbol, localSet, isBoxed = false, curEnv) + val result = + try + inContext(ctx.withOwner(mdef.symbol)) { + recheckStart(arg, pt).capturing(localSet) + } + finally curEnv = curEnv.outer + recheckFinish(result, arg, pt) + + override def recheckApply(tree: Apply, pt: Type)(using Context): Type = + if tree.symbol == defn.cbnArg then + recheckByNameArg(tree.args(0), pt) + else + includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckApply(tree, pt) + + override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + val res = super.recheck(tree, pt) + if tree.isTerm then + includeBoxedCaptures(res, tree.srcPos) + res + + override def checkUnit(unit: CompilationUnit)(using Context): Unit = + withCaptureSetsExplained { + super.checkUnit(unit) + PostRefinerCheck.traverse(unit.tpdTree) + if ctx.settings.YccDebug.value then + show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing + } + + def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = + if disallowGlobal then + tree match + case LambdaTypeTree(_, restpt) => + checkNotGlobal(restpt, allArgs*) + case _ => + for ref <- knownType(tree).captureSet.elems do + val isGlobal = ref match + case ref: TermRef => + ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) + case _ => false + val what = if ref.isRootCapability then "universal" else "global" + if isGlobal then + val notAllowed = i" is not allowed to capture the $what capability $ref" + def msg = tree match + case tree: InferredTypeTree => + i"""inferred type argument ${knownType(tree)}$notAllowed + | + |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + case _ => s"type argument$notAllowed" + report.error(msg, tree.srcPos) + + object PostRefinerCheck extends TreeTraverser: + def traverse(tree: Tree)(using Context) = + tree match + case _: InferredTypeTree => + case tree: TypeTree if !tree.span.isZeroExtent => + knownType(tree).foreachPart( + checkWellformedPost(_, tree.srcPos)) + knownType(tree).foreachPart { + case AnnotatedType(_, annot) => + checkWellformedPost(annot.tree) + case _ => + } + case tree1 @ TypeApply(fn, args) if disallowGlobal => + for arg <- args do + //println(i"checking $arg in $tree: ${knownType(tree).captureSet}") + checkNotGlobal(arg, args*) + case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] => + val sym = t.symbol + val isLocal = + sym.ownersIterator.exists(_.isTerm) + || sym.accessBoundary(defn.RootClass).isContainedIn(sym.topLevelClass) + + // The following classes of definitions need explicit capture types ... + if !isLocal // ... since external capture types are not inferred + || sym.owner.is(Trait) // ... since we do OverridingPairs checking before capture inference + || sym.allOverriddenSymbols.nonEmpty // ... since we do override checking before capture inference + then + val inferred = knownType(t.tpt) + def checkPure(tp: Type) = tp match + case CapturingType(_, refs, _) if !refs.elems.isEmpty => + val resultStr = if t.isInstanceOf[DefDef] then " result" else "" + report.error( + em"""Non-local $sym cannot have an inferred$resultStr type + |$inferred + |with non-empty capture set $refs. + |The type needs to be declared explicitly.""", t.srcPos) + case _ => + inferred.foreachPart(checkPure, StopAt.Static) + case _ => + traverseChildren(tree) + + def postRefinerCheck(tree: tpd.Tree)(using Context): Unit = + PostRefinerCheck.traverse(tree) + + end CaptureChecker +end CheckCaptures diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 3b743906fd51..116c8ff9bbfc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -74,9 +74,8 @@ object Checking { } for (arg, which, bound) <- TypeOps.boundsViolations(args, boundss, instantiate, app) do report.error( - showInferred(DoesNotConformToBound(arg.tpe, which, bound), - app, tpt), - arg.srcPos.focus) + showInferred(DoesNotConformToBound(arg.tpe, which, bound), app, tpt), + arg.srcPos.focus) /** Check that type arguments `args` conform to corresponding bounds in `tl` * Note: This does not check the bounds of AppliedTypeTrees. These @@ -310,6 +309,7 @@ object Checking { case AndType(tp1, tp2) => isInteresting(tp1) || isInteresting(tp2) case OrType(tp1, tp2) => isInteresting(tp1) && isInteresting(tp2) case _: RefinedOrRecType | _: AppliedType => true + case tp: AnnotatedType => isInteresting(tp.parent) case _ => false } diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 7654b98995ff..de44dd0efb18 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -14,6 +14,7 @@ import Decorators._ import config.Printers.{gadts, typr, debug} import annotation.tailrec import reporting._ +import cc.{CapturingType, derivedCapturingType} import collection.mutable import scala.annotation.internal.sharable @@ -126,8 +127,8 @@ object Inferencing { couldInstantiateTypeVar(parent) case tp: AndOrType => couldInstantiateTypeVar(tp.tp1) || couldInstantiateTypeVar(tp.tp2) - case AnnotatedType(tp, _) => - couldInstantiateTypeVar(tp) + case tp: AnnotatedType => + couldInstantiateTypeVar(tp.parent) case _ => false @@ -527,6 +528,7 @@ object Inferencing { case tp: RefinedType => tp.derivedRefinedType(captureWildcards(tp.parent), tp.refinedName, tp.refinedInfo) case tp: RecType => tp.derivedRecType(captureWildcards(tp.parent)) case tp: LazyRef => captureWildcards(tp.ref) + case CapturingType(parent, refs, _) => tp.derivedCapturingType(captureWildcards(parent), refs) case tp: AnnotatedType => tp.derivedAnnotatedType(captureWildcards(tp.parent), tp.annot) case _ => tp } @@ -696,6 +698,7 @@ trait Inferencing { this: Typer => if !argType.isSingleton then argType = SkolemType(argType) argType <:< tvar case _ => + () // scala-meta complains if this is missing, but I could not mimimize further end constrainIfDependentParamRef } @@ -710,4 +713,3 @@ trait Inferencing { this: Typer => enum IfBottom: case ok, fail, flip - diff --git a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala index 4a77573a8386..aa9b84428426 100644 --- a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala +++ b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala @@ -224,7 +224,7 @@ object RefChecks { * TODO This still needs to be cleaned up; the current version is a straight port of what was there * before, but it looks too complicated and method bodies are far too large. */ - private def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { + def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { val self = clazz.thisType val upwardsSelf = upwardsThisType(clazz) var hasErrors = false diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 2a5a9ca284ac..cc72918f8040 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -15,6 +15,7 @@ import ProtoTypes._ import collection.mutable import reporting._ import Checking.{checkNoPrivateLeaks, checkNoWildcard} +import cc.CaptureSet trait TypeAssigner { import tpd.* @@ -191,6 +192,14 @@ trait TypeAssigner { if tpe.isError then tpe else errorType(ex"$whatCanNot be accessed as a member of $pre$where.$whyNot", pos) + def processAppliedType(tree: untpd.Tree, tp: Type)(using Context): Type = tp match + case AppliedType(tycon, args) => + val constr = tycon.typeSymbol + if constr == defn.andType then AndType(args(0), args(1)) + else if constr == defn.orType then OrType(args(0), args(1), soft = false) + else tp + case _ => tp + /** Type assignment method. Each method takes as parameters * - an untpd.Tree to which it assigns a type, * - typed child trees it needs to access to cpmpute that type, @@ -288,8 +297,12 @@ trait TypeAssigner { val ownType = fn.tpe.widen match { case fntpe: MethodType => if (sameLength(fntpe.paramInfos, args) || ctx.phase.prev.relaxedTyping) - if (fntpe.isResultDependent) safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) - else fntpe.resultType + if fntpe.isCaptureDependent then + fntpe.resultType.substParams(fntpe, args.tpes) + else if fntpe.isResultDependent then + safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) + else + fntpe.resultType else errorType(i"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos) case t => @@ -461,11 +474,10 @@ trait TypeAssigner { assert(!hasNamedArg(args) || ctx.reporter.errorsReported, tree) val tparams = tycon.tpe.typeParams val ownType = - if (sameLength(tparams, args)) - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.appliedTo(args.tpes) - else wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + if !sameLength(tparams, args) then + wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + else + processAppliedType(tree, tycon.tpe.appliedTo(args.tpes)) tree.withType(ownType) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8883b492e2d9..39aabe6ce6c4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -49,6 +49,7 @@ import transform.TypeUtils._ import reporting._ import Nullables._ import NullOpsDecorator._ +import cc.CheckCaptures import config.Config object Typer { @@ -1134,8 +1135,8 @@ class Typer extends Namer */ private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = { def typeTree(tp: Type) = tp match { - case _: WildcardType => untpd.TypeTree() - case _ => untpd.TypeTree(tp) + case _: WildcardType => new untpd.InferredTypeTree() + case _ => untpd.InferredTypeTree(tp) } def interpolateWildcards = new TypeMap { def apply(t: Type): Type = t match @@ -1144,7 +1145,7 @@ class Typer extends Namer case _ => mapOver(t) } - val pt1 = pt.stripTypeVar.dealias + val pt1 = pt.strippedDealias if (pt1 ne pt1.dropDependentRefinement) && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType) then @@ -2560,6 +2561,8 @@ class Typer extends Namer registerNowarn(annot1, tree) val arg1 = typed(tree.arg, pt) if (ctx.mode is Mode.Type) { + if annot1.symbol.maybeOwner == defn.RetainsAnnot then + CheckCaptures.checkWellformed(annot1) if arg1.isType then assignType(cpy.Annotated(tree)(arg1, annot1), arg1, annot1) else diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index ffca320d53d3..1fac0dac0913 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -12,8 +12,10 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def contains[E >: Elem <: AnyRef](x: E): Boolean def foreach(f: Elem => Unit): Unit def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A def toList: List[Elem] + def iterator: Iterator[Elem] final def isEmpty: Boolean = size == 0 @@ -55,8 +57,10 @@ object SimpleIdentitySet { def contains[E <: AnyRef](x: E): Boolean = false def foreach(f: Nothing => Unit): Unit = () def exists[E <: AnyRef](p: E => Boolean): Boolean = false + def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty def /: [A, E <: AnyRef](z: A)(f: (A, E) => A): A = z def toList = Nil + def iterator = Iterator.empty } private class Set1[+Elem <: AnyRef](x0: AnyRef) extends SimpleIdentitySet[Elem] { @@ -69,9 +73,12 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = f(x0.asInstanceOf[Elem]) def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set1(f(x0.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(z, x0.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: Nil + def iterator = Iterator.single(x0.asInstanceOf[Elem]) } private class Set2[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef) extends SimpleIdentitySet[Elem] { @@ -86,9 +93,15 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = { f(x0.asInstanceOf[Elem]); f(x1.asInstanceOf[Elem]) } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set2(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(2) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + } } private class Set3[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef, x2: AnyRef) extends SimpleIdentitySet[Elem] { @@ -114,9 +127,16 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) || p(x2.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set3(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem]), f(x2.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]), x2.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: x2.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(3) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + case 2 => x2.asInstanceOf[Elem] + } } private class SetN[+Elem <: AnyRef](val xs: Array[AnyRef]) extends SimpleIdentitySet[Elem] { @@ -156,6 +176,8 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = xs.asInstanceOf[Array[E]].exists(p) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + SetN(xs.map(x => f(x.asInstanceOf[Elem]).asInstanceOf[AnyRef])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = xs.asInstanceOf[Array[E]].foldLeft(z)(f) def toList: List[Elem] = { @@ -163,6 +185,7 @@ object SimpleIdentitySet { foreach(buf += _) buf.toList } + def iterator = xs.iterator.asInstanceOf[Iterator[Elem]] override def ++ [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] = that match { case that: SetN[?] => diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 7ddbe5832e5f..13f99620a449 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -39,6 +39,7 @@ class CompilationTests { compileFilesInDir("tests/pos-special/isInstanceOf", allowDeepSubtypes.and("-Xfatal-warnings")), compileFilesInDir("tests/new", defaultOptions), compileFilesInDir("tests/pos-scala2", scala2CompatMode), + compileFilesInDir("tests/pos-custom-args/captures", defaultOptions.and("-Ycc")), compileFilesInDir("tests/pos-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")), compileFilesInDir("tests/pos", defaultOptions.and("-Ysafe-init")), compileFilesInDir("tests/pos-deep-subtype", allowDeepSubtypes), @@ -136,6 +137,7 @@ class CompilationTests { compileFilesInDir("tests/neg-custom-args/allow-deep-subtypes", allowDeepSubtypes), compileFilesInDir("tests/neg-custom-args/explicit-nulls", defaultOptions.and("-Yexplicit-nulls")), compileFilesInDir("tests/neg-custom-args/no-experimental", defaultOptions.and("-Yno-experimental")), + compileFilesInDir("tests/neg-custom-args/captures", defaultOptions.and("-Ycc")), compileDir("tests/neg-custom-args/impl-conv", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions-old.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), @@ -180,6 +182,7 @@ class CompilationTests { compileFile("tests/neg-custom-args/i7314.scala", defaultOptions.and("-Xfatal-warnings", "-source", "future")), compileFile("tests/neg-custom-args/feature-shadowing.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileDir("tests/neg-custom-args/hidden-type-errors", defaultOptions.and("-explain")), + compileFile("tests/neg-custom-args/capt-wf.scala", defaultOptions.and("-Ycc", "-Xfatal-warnings")), ).checkExpectedErrors() } diff --git a/library/src-bootstrapped/scala/Retains.scala b/library/src-bootstrapped/scala/Retains.scala new file mode 100644 index 000000000000..f3bfa282a012 --- /dev/null +++ b/library/src-bootstrapped/scala/Retains.scala @@ -0,0 +1,6 @@ +package scala + +/** An annotation that indicates capture + */ +class retains(xs: Any*) extends annotation.StaticAnnotation + diff --git a/library/src-bootstrapped/scala/annotation/ability.scala b/library/src-bootstrapped/scala/annotation/ability.scala new file mode 100644 index 000000000000..8b327a2f8b02 --- /dev/null +++ b/library/src-bootstrapped/scala/annotation/ability.scala @@ -0,0 +1,9 @@ +package scala.annotation + +/** An annotation inidcating that a val should be tracked as its own ability. + * Example: + * + * @ability erased val canThrow: * = ??? + * ^^^ rename to capability + */ +class ability extends StaticAnnotation \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/Predef.scala b/library/src/scala/runtime/stdLibPatches/Predef.scala index 13dfc77ac60b..387096ab55c5 100644 --- a/library/src/scala/runtime/stdLibPatches/Predef.scala +++ b/library/src/scala/runtime/stdLibPatches/Predef.scala @@ -47,4 +47,5 @@ object Predef: */ extension [T](x: T | Null) inline def nn: x.type & T = scala.runtime.Scala3RunTime.nn(x) + end Predef diff --git a/tests/disabled/neg-custom-args/captures/capt-wf.scala b/tests/disabled/neg-custom-args/captures/capt-wf.scala new file mode 100644 index 000000000000..54fe545f443b --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/capt-wf.scala @@ -0,0 +1,19 @@ +// No longer valid +class C +type Cap = C @retains(*) +type Top = Any @retains(*) + +type T = (x: Cap) => List[String @retains(x)] => Unit // error +val x: (x: Cap) => Array[String @retains(x)] = ??? // error +val y = x + +def test: Unit = + def f(x: Cap) = // ok + val g = (xs: List[String @retains(x)]) => () + g + def f2(x: Cap)(xs: List[String @retains(x)]) = () + val x = f // error + val x2 = f2 // error + val y = f(C()) // ok + val y2 = f2(C()) // ok + () diff --git a/tests/disabled/neg-custom-args/captures/try2.check b/tests/disabled/neg-custom-args/captures/try2.check new file mode 100644 index 000000000000..c7b20d0f7c5e --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.check @@ -0,0 +1,38 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:31:32 ----------------------------------------- +31 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => Nothing + | Required: () => Nothing + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:45:2 ------------------------------------------ +45 | yy // error + | ^^ + | Found: (yy : List[(xx : (() => Int) retains canThrow)]) + | Required: List[() => Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:52:2 ------------------------------------------ +47 |val global = handle { +48 | (x: CanThrow[Exception]) => +49 | () => +50 | raise(new Exception)(using x) +51 | 22 +52 |} { // error + | ^ + | Found: (() => Int) retains canThrow + | Required: () => Int +53 | (ex: Exception) => () => 22 +54 |} + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try2.scala:24:28 -------------------------------------------------------------- +24 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the global capability (canThrow : *) +-- Error: tests/neg-custom-args/captures/try2.scala:36:11 -------------------------------------------------------------- +36 | val xx = handle { // error + | ^^^^^^ + |inferred type argument ((() => Int) retains canThrow) is not allowed to capture the global capability (canThrow : *) + | + |The inferred arguments are: [Exception, ((() => Int) retains canThrow)] diff --git a/tests/disabled/neg-custom-args/captures/try2.scala b/tests/disabled/neg-custom-args/captures/try2.scala new file mode 100644 index 000000000000..dd3cc890a197 --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.scala @@ -0,0 +1,55 @@ +// Retains syntax for classes not (yet?) supported +import language.experimental.erasedDefinitions +import annotation.ability + +@ability erased val canThrow: * = ??? + +class CanThrow[E <: Exception] extends Retains[canThrow.type] +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: List[() => Int] = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // error + +val global = handle { + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { // error + (ex: Exception) => () => 22 +} diff --git a/tests/disabled/pos/lazylist.scala b/tests/disabled/pos/lazylist.scala new file mode 100644 index 000000000000..be628113d2d8 --- /dev/null +++ b/tests/disabled/pos/lazylist.scala @@ -0,0 +1,51 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] + +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail: {*} LazyList[T] = xs() + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] = + LazyCons(x, () => xs().concat(that)) +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] = +// f(x).concat(xs().flatMap(f)) + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + def concat[U](that: {*} LazyList[U]): {that} LazyList[U] = that +// def flatMap[U](f: {*} Nothing => LazyList[U]): LazyList[U] = LazyNil + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: {cap1} LazyList[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyList[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyList[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, cap2, cap3} LazyList[Int] = ref4 \ No newline at end of file diff --git a/tests/neg/i9325.scala b/tests/neg-custom-args/allow-deep-subtypes/i9325.scala similarity index 100% rename from tests/neg/i9325.scala rename to tests/neg-custom-args/allow-deep-subtypes/i9325.scala diff --git a/tests/neg-custom-args/capt-wf.scala b/tests/neg-custom-args/capt-wf.scala new file mode 100644 index 000000000000..dc4d6a0d4bff --- /dev/null +++ b/tests/neg-custom-args/capt-wf.scala @@ -0,0 +1,35 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x1: {*} C = ??? // OK + val x2: {other} C = ??? // error: cs is empty + val s1 = () => "abc" + val x3: {s1} C = ??? // error: cs is empty + val x3a: () => String = s1 + val s2 = () => if x1 == null then "" else "abc" + val x4: {s2} C = ??? // OK + val x5: {c, c} C = ??? // error: redundant + val x6: {c} {c} C = ??? // error: redundant + val x7: {c} Cap = ??? // error: redundant + val x8: {*} {c} C = ??? // OK + val x9: {c, *} C = ??? // error: redundant + val x10: {*, c} C = ??? // error: redundant + + def even(n: Int): Boolean = if n == 0 then true else odd(n - 1) + def odd(n: Int): Boolean = if n == 1 then true else even(n - 1) + val e1 = even + val o1 = odd + + val y1: {e1} String = ??? // error cs is empty + val y2: {o1} String = ??? // error cs is empty + + lazy val ev: (Int => Boolean) = (n: Int) => + lazy val od: (Int => Boolean) = (n: Int) => + if n == 1 then true else ev(n - 1) + if n == 0 then true else od(n - 1) + val y3: {ev} String = ??? // error cs is empty + + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/bounded.scala b/tests/neg-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..dc2621e95a65 --- /dev/null +++ b/tests/neg-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: () => {c} Int => Int = r2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/boxmap.check b/tests/neg-custom-args/captures/boxmap.check new file mode 100644 index 000000000000..406077077af5 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:14:2 ---------------------------------------- +14 | () => b[Box[B]]((x: A) => box(f(x))) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {f} () => ? Box[B] + | Required: () => Box[B] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/boxmap.scala b/tests/neg-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..e335320ef9d4 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.scala @@ -0,0 +1,14 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): () => Box[B] = + () => b[Box[B]]((x: A) => box(f(x))) // error diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala new file mode 100644 index 000000000000..526cdc50952f --- /dev/null +++ b/tests/neg-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap) = + def f() = if cap1 == cap1 then g else g + def g(x: Int) = if cap2 == cap2 then 1 else x + def h(ff: => {cap2} Int => Int) = ff + h(f()) // error + + diff --git a/tests/neg-custom-args/captures/capt-box-env.scala b/tests/neg-custom-args/captures/capt-box-env.scala new file mode 100644 index 000000000000..e9743054076e --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box-env.scala @@ -0,0 +1,12 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error diff --git a/tests/neg-custom-args/captures/capt-box.scala b/tests/neg-custom-args/captures/capt-box.scala new file mode 100644 index 000000000000..317fc064ec0b --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box.scala @@ -0,0 +1,13 @@ +//import scala.retains +class C +type Cap = {*} C + +def test(x: Cap) = + + def foo(y: Cap) = if x == y then println() + + val x1 = foo + + val x2 = identity(x1) + + val x3: Cap => Unit = x2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-depfun.scala b/tests/neg-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b0beb92b313 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun.scala @@ -0,0 +1,7 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val dc: (({y, z} String) => {y, z} String) = ac(g()) // error diff --git a/tests/neg-custom-args/captures/capt-depfun2.scala b/tests/neg-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..874d753b048d --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,10 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc = ac(g()) // error: Needs explicit type Array[? >: String <: {y, z} String] + // This is a shortcoming of rechecking since the originally inferred + // type is `Array[String]` and the actual type after rechecking + // cannot be expressed as `Array[C String]` for any capture set C \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-env.scala b/tests/neg-custom-args/captures/capt-env.scala new file mode 100644 index 000000000000..84b4b57a7930 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-env.scala @@ -0,0 +1,13 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error + diff --git a/tests/neg-custom-args/captures/capt-test.scala b/tests/neg-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..0c536a280f5c --- /dev/null +++ b/tests/neg-custom-args/captures/capt-test.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: (CanThrow[E]) => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: Unit = + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } diff --git a/tests/neg-custom-args/captures/capt-wf-typer.scala b/tests/neg-custom-args/captures/capt-wf-typer.scala new file mode 100644 index 000000000000..5120e2b288d5 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-wf-typer.scala @@ -0,0 +1,10 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x7: {c} String = ??? // OK + val x8: String @retains(x7 + x7) = ??? // error + val x9: String @retains(foo) = ??? // error + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt1.check b/tests/neg-custom-args/captures/capt1.check new file mode 100644 index 000000000000..ce7c4833bf9c --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.check @@ -0,0 +1,46 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:3:2 ------------------------------------------ +3 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: () => C + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:6:2 ------------------------------------------ +6 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: Matchable + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:13:2 ----------------------------------------- +13 | def f(y: Int) = if x == null then y else y // error + | ^ + | Found: {x} Int => Int + | Required: Matchable +14 | f + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:20:2 ----------------------------------------- +20 | class F(y: Int) extends A: // error + | ^ + | Found: {x} A + | Required: A +21 | def m() = if x == null then y else y +22 | F(22) + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:25:2 ----------------------------------------- +25 | new A: // error + | ^ + | Found: {x} A + | Required: A +26 | def m() = if x == null then y else y + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ---------------------------------------- +31 | val z2 = h[() => Cap](() => x)(() => C()) // error + | ^^^^^^^ + | Found: {x} () => ? Cap + | Required: () => Cap + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/capt1.scala b/tests/neg-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..4da49c5f4f1e --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.scala @@ -0,0 +1,34 @@ +class C +def f(x: C @retains(*), y: C): () => C = + () => if x == null then y else y // error + +def g(x: C @retains(*), y: C): Matchable = + () => if x == null then y else y // error + +def h1(x: C @retains(*), y: C): Any = + def f() = if x == null then y else y + () => f() // ok + +def h2(x: C @retains(*)): Matchable = + def f(y: Int) = if x == null then y else y // error + f + +class A +type Cap = C @retains(*) + +def h3(x: Cap): A = + class F(y: Int) extends A: // error + def m() = if x == null then y else y + F(22) + +def h4(x: Cap, y: Int): A = + new A: // error + def m() = if x == null then y else y + +def foo() = + val x: C @retains(*) = ??? + def h[X](a: X)(b: X) = a + val z2 = h[() => Cap](() => x)(() => C()) // error + val z3 = h[(() => Cap) @retains(x)](() => x)(() => C()) // ok + val z4 = h[(() => Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3 + diff --git a/tests/neg-custom-args/captures/capt2.scala b/tests/neg-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..1eee53463f6d --- /dev/null +++ b/tests/neg-custom-args/captures/capt2.scala @@ -0,0 +1,9 @@ +//import scala.retains +class C +type Cap = {*} C + +def f1(c: Cap): (() => {c} C) = () => c // error, but would be OK under capture abbreciations for funciton types +def f2(c: Cap): ({c} () => C) = () => c // error + +def h5(x: Cap): () => C = + f1(x) // error diff --git a/tests/neg-custom-args/captures/capt3.scala b/tests/neg-custom-args/captures/capt3.scala new file mode 100644 index 000000000000..80b937276f73 --- /dev/null +++ b/tests/neg-custom-args/captures/capt3.scala @@ -0,0 +1,26 @@ +class C +type Cap = C @retains(*) + +def test1() = + val x: Cap = C() + val y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test2() = + val x: Cap = C() + def y = () => { x; () } + def z = y + z: (() => Unit) // error + +def test3() = + val x: Cap = C() + def y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test4() = + val x: Cap = C() + val y = () => { x; () } + def z = y + z: (() => Unit) // error diff --git a/tests/neg-custom-args/captures/cc1.scala b/tests/neg-custom-args/captures/cc1.scala new file mode 100644 index 000000000000..ebd983c58fe9 --- /dev/null +++ b/tests/neg-custom-args/captures/cc1.scala @@ -0,0 +1,4 @@ +object Test: + + def f[A <: Matchable @retains(*)](x: A): Matchable = x // error + diff --git a/tests/neg-custom-args/captures/classes.scala b/tests/neg-custom-args/captures/classes.scala new file mode 100644 index 000000000000..b87d21913d4e --- /dev/null +++ b/tests/neg-custom-args/captures/classes.scala @@ -0,0 +1,12 @@ +class B +type Cap = {*} B +class C0(n: Cap) // error: class parameter must be a `val`. + +class C(val n: Cap): + def foo(): {n} B = n + +def test(x: Cap, y: Cap) = + val c0 = C(x) + val c1: C = c0 // error + val c2 = if ??? then C(x) else /*identity*/(C(y)) // TODO: uncomment + val c3: {x} C { val n: {x, y} B } = c2 // error diff --git a/tests/neg-custom-args/captures/io.scala b/tests/neg-custom-args/captures/io.scala new file mode 100644 index 000000000000..17c22a2111e4 --- /dev/null +++ b/tests/neg-custom-args/captures/io.scala @@ -0,0 +1,22 @@ +sealed trait IO: + def puts(msg: Any): Unit = println(msg) + +def test1 = + val IO : IO @retains(*) = new IO {} + def foo = {IO; IO.puts("hello") } + val x : () => Unit = () => foo // error: Found: (() => Unit) retains IO; Required: () => Unit + +def test2 = + val IO : IO @retains(*) = new IO {} + def puts(msg: Any, io: IO @retains(*)) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + +type Capability[T] = T @retains(*) + +def test3 = + val IO : Capability[IO] = new IO {} + def puts(msg: Any, io: Capability[IO]) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check new file mode 100644 index 000000000000..3a80de9bdf16 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.check @@ -0,0 +1,42 @@ +-- [E163] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ---------------------------------------- +22 | def tail: {*} LazyList[Nothing] = ??? // error overriding + | ^ + | error overriding method tail in class LazyList of type => lazylists.LazyList[Nothing]; + | method tail of type => {*} lazylists.LazyList[Nothing] has incompatible type + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 ------------------------------------- +35 | val ref1c: LazyList[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} lazylists.LazyCons[Int]{xs: {cap1} () => {*} lazylists.LazyList[Int]}) + | Required: lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:37:36 ------------------------------------- +37 | val ref2c: {ref1} LazyList[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {ref1} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:39:36 ------------------------------------- +39 | val ref3c: {cap2} LazyList[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {cap2} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:41:48 ------------------------------------- +41 | val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap3, cap2, ref1, cap1} lazylists.LazyList[Int]) + | Required: {cap1, ref3, cap3} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/lazylist.scala:17:6 ----------------------------------------------------------- +17 | def tail = xs() // error: cannot have an inferred type + | ^^^^^^^^^^^^^^^ + | Non-local method tail cannot have an inferred result type + | {*} lazylists.LazyList[T] + | with non-empty capture set {*}. + | The type needs to be declared explicitly. diff --git a/tests/neg-custom-args/captures/lazylist.scala b/tests/neg-custom-args/captures/lazylist.scala new file mode 100644 index 000000000000..f7be43e8dc27 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.scala @@ -0,0 +1,41 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail = xs() // error: cannot have an inferred type + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail: {*} LazyList[Nothing] = ??? // error overriding + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: LazyList[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {ref1} LazyList[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {cap2} LazyList[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/lazyref.check b/tests/neg-custom-args/captures/lazyref.check new file mode 100644 index 000000000000..2affed020dec --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.check @@ -0,0 +1,28 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:19:28 -------------------------------------- +19 | val ref1c: LazyRef[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} LazyRef[Int]{elem: {cap1} () => Int}) + | Required: LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:21:35 -------------------------------------- +21 | val ref2c: {cap2} LazyRef[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap2} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:23:35 -------------------------------------- +23 | val ref3c: {ref1} LazyRef[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {ref1} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:25:35 -------------------------------------- +25 | val ref4c: {cap1} LazyRef[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap2, cap1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap1} LazyRef[Int] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazyref.scala b/tests/neg-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..1002f9685675 --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: LazyRef[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {cap2} LazyRef[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {ref1} LazyRef[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1} LazyRef[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check new file mode 100644 index 000000000000..bd95835c6525 --- /dev/null +++ b/tests/neg-custom-args/captures/try.check @@ -0,0 +1,25 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------ +28 | val b = handle[Exception, () => Nothing] { // error + | ^ + | Found: ? (x: CanThrow[Exception]) => {x} () => ? Nothing + | Required: CanThrow[Exception] => () => Nothing +29 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) +30 | } { + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try.scala:22:28 --------------------------------------------------------------- +22 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/try.scala:34:11 --------------------------------------------------------------- +34 | val xx = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] +-- Error: tests/neg-custom-args/captures/try.scala:46:13 --------------------------------------------------------------- +46 |val global = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala new file mode 100644 index 000000000000..804a16192be0 --- /dev/null +++ b/tests/neg-custom-args/captures/try.scala @@ -0,0 +1,53 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // OK + + +val global = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { + (ex: Exception) => () => 22 +} \ No newline at end of file diff --git a/tests/neg-custom-args/captures/try3.scala b/tests/neg-custom-args/captures/try3.scala new file mode 100644 index 000000000000..4fbb980b9e03 --- /dev/null +++ b/tests/neg-custom-args/captures/try3.scala @@ -0,0 +1,27 @@ +import java.io.IOException + +class CT[E] +type CanThrow[E] = {*} CT[E] +type Top = {*} Any + +def handle[E <: Exception, T <: Top](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +@main def Test: Int = + def f(a: Boolean) = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception diff --git a/tests/neg/multiLineOps.scala b/tests/neg/multiLineOps.scala index 8499cc9fe710..08a0a3925fd1 100644 --- a/tests/neg/multiLineOps.scala +++ b/tests/neg/multiLineOps.scala @@ -5,7 +5,7 @@ val x = 1 val b1 = { 22 * 22 // ok - */*one more*/22 // error: end of statement expected // error: not found: * + */*one more*/22 // error: end of statement expected } val b2: Boolean = { diff --git a/tests/neg/polymorphic-functions1.check b/tests/neg/polymorphic-functions1.check new file mode 100644 index 000000000000..86492e96dab5 --- /dev/null +++ b/tests/neg/polymorphic-functions1.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 --------------------------------------------- +1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error + | ^ + | Found: [T] => (x: Int) => Int + | Required: [T] => (x: T) => x.type + +longer explanation available when compiling with `-explain` diff --git a/tests/neg/polymorphic-functions1.scala b/tests/neg/polymorphic-functions1.scala new file mode 100644 index 000000000000..de887f3b8c50 --- /dev/null +++ b/tests/neg/polymorphic-functions1.scala @@ -0,0 +1 @@ +val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error diff --git a/tests/pos-custom-args/captures/bounded.scala b/tests/pos-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..fad0b50c2137 --- /dev/null +++ b/tests/pos-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: {c} () => {c} Int => Int = r2 \ No newline at end of file diff --git a/tests/pos-custom-args/captures/boxmap-paper.scala b/tests/pos-custom-args/captures/boxmap-paper.scala new file mode 100644 index 000000000000..ed8c648526d1 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap-paper.scala @@ -0,0 +1,38 @@ +infix type ==> [A, B] = {*} (A => B) + +type Cell[+T] = [K] => (T ==> K) => K + +def cell[T](x: T): Cell[T] = + [K] => (k: T ==> K) => k(x) + +def get[T](c: Cell[T]): T = c[T](identity) + +def map[A, B](c: Cell[A])(f: A ==> B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def pureMap[A, B](c: Cell[A])(f: A => B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def lazyMap[A, B](c: Cell[A])(f: A ==> B): {f} () => Cell[B] + = () => c[Cell[B]]((x: A) => cell(f(x))) + +trait IO: + def print(s: String): Unit + +def test(io: {*} IO) = + + val loggedOne: {io} () => Int = () => { io.print("1"); 1 } + + val c: Cell[{io} () => Int] + = cell[{io} () => Int](loggedOne) + + val g = (f: {io} () => Int) => + val x = f(); io.print(" + ") + val y = f(); io.print(s" = ${x + y}") + + val r = lazyMap[{io} () => Int, Unit](c)(f => g(f)) + val r2 = lazyMap[{io} () => Int, Unit](c)(g) + val r3 = lazyMap(c)(g) + val _ = r() + val _ = r2() + val _ = r3() diff --git a/tests/pos-custom-args/captures/boxmap.scala b/tests/pos-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..a0dcade2b179 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap.scala @@ -0,0 +1,20 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): (() => Box[B]) @retains(f) = + () => b[Box[B]]((x: A) => box(f(x))) + +def test[A <: Top, B <: Top] = + def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B) = + () => b[Box[B]]((x: A) => box(f(x))) + val x: (b: Box[A]) => (f: A ==> B) => (() => Box[B]) @retains(f) = lazymap[A, B] + () diff --git a/tests/pos-custom-args/captures/byname.scala b/tests/pos-custom-args/captures/byname.scala new file mode 100644 index 000000000000..917154079b36 --- /dev/null +++ b/tests/pos-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +class I + +def test(cap1: Cap, cap2: Cap): {cap1} I = + def f() = if cap1 == cap1 then I() else I() + def h(x: => {cap1} I) = x + h(f()) + diff --git a/tests/pos-custom-args/captures/capt-depfun.scala b/tests/pos-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b99eff32692 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun.scala @@ -0,0 +1,18 @@ +class C +type Cap = C @retains(*) + +type T = (x: Cap) => String @retains(x) + +val aa: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + +def f(y: Cap, z: Cap): String @retains(*) = + val a: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + val b = a(y) + val c: String @retains(y) = b + def g(): C @retains(y, z) = ??? + val d = a(g()) + + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val bc: (({y} String) => {y} String) = ac(y) + val dc: (String => {y, z} String) = ac(g()) + c diff --git a/tests/pos-custom-args/captures/capt-depfun2.scala b/tests/pos-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..17f98b4a1554 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,8 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc: Array[? >: String <: {y, z} String] = ac(g()) // needs to be inferred + val ec = ac(y) diff --git a/tests/pos-custom-args/captures/capt-test.scala b/tests/pos-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..f40bd2ff1746 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-test.scala @@ -0,0 +1,35 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + + val a4 = zs.map(identity) + val a4c: LIST[{d, y} Cap => Unit] = a4 diff --git a/tests/pos-custom-args/captures/capt0.scala b/tests/pos-custom-args/captures/capt0.scala new file mode 100644 index 000000000000..c8ff8a102856 --- /dev/null +++ b/tests/pos-custom-args/captures/capt0.scala @@ -0,0 +1,7 @@ +object Test: + + def test() = + val x: {*} Any = "abc" + val y: Object @scala.retains(x) = ??? + val z: Object @scala.retains(x, *) = y: Object @scala.retains(x) + diff --git a/tests/pos-custom-args/captures/capt1.scala b/tests/pos-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..14c0855544d4 --- /dev/null +++ b/tests/pos-custom-args/captures/capt1.scala @@ -0,0 +1,27 @@ +class C +type Cap = {*} C +def f1(c: Cap): {c} () => c.type = () => c // ok + +def f2: Int = + val g: {*} Boolean => Int = ??? + val x = g(true) + x + +def f3: Int = + def g: {*} Boolean => Int = ??? + def h = g + val x = g.apply(true) + x + +def foo() = + val x: {*} C = ??? + val y: {x} C = x + val x2: {x} () => C = ??? + val y2: {x} () => {x} C = x2 + + val z1: {*} () => Cap = f1(x) + def h[X](a: X)(b: X) = a + + val z2 = + if x == null then () => x else () => C() + x \ No newline at end of file diff --git a/tests/pos-custom-args/captures/capt2.scala b/tests/pos-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..e3d4cd67b30c --- /dev/null +++ b/tests/pos-custom-args/captures/capt2.scala @@ -0,0 +1,20 @@ +import scala.retains +class C +type Cap = C @retains(*) + +def test1() = + val y: {*} String = "" + def x: Object @retains(y) = y + +def test2() = + val x: Cap = C() + val y = () => { x; () } + def z: (() => Unit) @retains(x) = y + z: (() => Unit) @retains(x) + def z2: (() => Unit) @retains(y) = y + z2: (() => Unit) @retains(y) + val p: {*} () => String = () => "abc" + val q: {p} C = ??? + p: ({p} () => String) + + diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala new file mode 100644 index 000000000000..eedc95554b17 --- /dev/null +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -0,0 +1,21 @@ +object Test: + + class A + class B + class C + class CTC + type CT = CTC @retains(*) + + def test(ct: CT, dt: CT) = + + def x0: A => {ct} B = ??? + + def x1: A => B @retains(ct) = ??? + def x2: A => B => C @retains(ct) = ??? + def x3: A => () => B => C @retains(ct) = ??? + + def x4: (x: A @retains(ct)) => B => C = ??? + + def x5: A => (x: B @retains(ct)) => () => C @retains(dt) = ??? + def x6: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x, dt) = ??? + def x7: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x) = ??? \ No newline at end of file diff --git a/tests/pos-custom-args/captures/classes.scala b/tests/pos-custom-args/captures/classes.scala new file mode 100644 index 000000000000..f3d6e44b27ca --- /dev/null +++ b/tests/pos-custom-args/captures/classes.scala @@ -0,0 +1,34 @@ +class B +type Cap = {*} B +class C(val n: Cap): + this: ({n} C) => + def foo(): {n} B = n + + +def test(x: Cap, y: Cap, z: Cap) = + val c0 = C(x) + val c1: {x} C {val n: {x} B} = c0 + val d = c1.foo() + d: ({x} B) + + val c2 = if ??? then C(x) else C(y) + val c2a = identity(c2) + val c3: {x, y} C { val n: {x, y} B } = c2 + val d1 = c3.foo() + d1: B @retains(x, y) + + class Local: + + def this(a: Cap) = + this() + if a == z then println("?") + + val f = y + def foo = x + end Local + + val l = Local() + val l1: {x, y} Local = l + val l2 = Local(x) + val l3: {x, y, z} Local = l2 + diff --git a/tests/pos-custom-args/captures/iterators.scala b/tests/pos-custom-args/captures/iterators.scala new file mode 100644 index 000000000000..dd1067bcdc72 --- /dev/null +++ b/tests/pos-custom-args/captures/iterators.scala @@ -0,0 +1,23 @@ +package cctest + +abstract class Iterator[T]: + thisIterator => + + def hasNext: Boolean + def next: T + def map(f: {*} T => T): {f} Iterator[T] = new Iterator: + def hasNext = thisIterator.hasNext + def next = f(thisIterator.next) +end Iterator + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + val it = new Iterator[Int]: + private var ctr = 0 + def hasNext = ctr < 10 + def next = { ctr += 1; ctr } + + def f(x: Int): Int = if c == d then x else 10 + val it2 = it.map(f) diff --git a/tests/pos-custom-args/captures/lazyref.scala b/tests/pos-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..39748b00506b --- /dev/null +++ b/tests/pos-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: {cap1} LazyRef[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyRef[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyRef[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1, cap2} LazyRef[Int] = ref4 diff --git a/tests/pos-custom-args/captures/list-encoding.scala b/tests/pos-custom-args/captures/list-encoding.scala new file mode 100644 index 000000000000..74bc8bd2b099 --- /dev/null +++ b/tests/pos-custom-args/captures/list-encoding.scala @@ -0,0 +1,23 @@ +package listEncoding + +class Cap + +type Op[T, C] = + {*} (v: T) => {*} (s: C) => C + +type List[T] = + [C] => (op: Op[T, C]) => {op} (s: C) => C + +def nil[T]: List[T] = + [C] => (op: Op[T, C]) => (s: C) => s + +def cons[T](hd: T, tl: List[T]): List[T] = + [C] => (op: Op[T, C]) => (s: C) => op(hd)(tl(op)(s)) + +def foo(c: {*} Cap) = + def f(x: String @retains(c), y: String @retains(c)) = + cons(x, cons(y, nil)) + def g(x: String @retains(c), y: Any) = + cons(x, cons(y, nil)) + def h(x: String, y: Any @retains(c)) = + cons(x, cons(y, nil)) diff --git a/tests/pos-custom-args/captures/lists.scala b/tests/pos-custom-args/captures/lists.scala new file mode 100644 index 000000000000..139f885ec87a --- /dev/null +++ b/tests/pos-custom-args/captures/lists.scala @@ -0,0 +1,91 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + val z1 = zs.head + val z1c: {y, d} Cap => Unit = z1 + val ys1 = zs.tail + val y1 = ys1.head + + + def m1[A, B] = + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m1c: (f: {*} String => Int) => {f} LIST[String] => LIST[Int] = m1[String, Int] + + def m2 = [A, B] => + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m2c: [A, B] => (f: {*} A => B) => {f} LIST[A] => LIST[B] = m2 + + def eff[A](x: A) = if x == e then x else x + + val eff2 = [A] => (x: A) => if x == e then x else x + + val a0 = identity[{d, y} Cap => Unit] + val a0c: ({d, y} Cap => Unit) => {d, y} Cap => Unit = a0 + val a1 = zs.map[{d, y} Cap => Unit](a0) + val a1c: LIST[{d, y} Cap => Unit] = a1 + val a2 = zs.map[{d, y} Cap => Unit](identity[{d, y} Cap => Unit]) + val a2c: LIST[{d, y} Cap => Unit] = a2 + val a3 = zs.map(identity[{d, y} Cap => Unit]) + val a3c: LIST[{d, y} Cap => Unit] = a3 + val a4 = zs.map(identity) + val a4c: LIST[{d, c} Cap => Unit] = a4 + val a5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a5c: LIST[{d, c} Cap => Unit] = a5 + val a6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a6c: LIST[{d, c} Cap => Unit] = a6 + + val b0 = eff[{d, y} Cap => Unit] + val b0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = b0 + val b1 = zs.map[{d, y} Cap => Unit](a0) + val b1c: {e} LIST[{d, y} Cap => Unit] = b1 + val b2 = zs.map[{d, y} Cap => Unit](eff[{d, y} Cap => Unit]) + val b2c: {e} LIST[{d, y} Cap => Unit] = b2 + val b3 = zs.map(eff[{d, y} Cap => Unit]) + val b3c: {e} LIST[{d, y} Cap => Unit] = b3 + val b4 = zs.map(eff) + val b4c: {e} LIST[{d, c} Cap => Unit] = b4 + val b5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b5c: {e} LIST[{d, c} Cap => Unit] = b5 + val b6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b6c: {e} LIST[{d, c} Cap => Unit] = b6 + + val c0 = eff2[{d, y} Cap => Unit] + val c0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = c0 + val c1 = zs.map[{d, y} Cap => Unit](a0) + val c1c: {e} LIST[{d, y} Cap => Unit] = c1 + val c2 = zs.map[{d, y} Cap => Unit](eff2[{d, y} Cap => Unit]) + val c2c: {e} LIST[{d, y} Cap => Unit] = c2 + val c3 = zs.map(eff2[{d, y} Cap => Unit]) + val c3c: {e} LIST[{d, y} Cap => Unit] = c3 + diff --git a/tests/pos-custom-args/captures/pairs.scala b/tests/pos-custom-args/captures/pairs.scala new file mode 100644 index 000000000000..4f23a086a075 --- /dev/null +++ b/tests/pos-custom-args/captures/pairs.scala @@ -0,0 +1,33 @@ + +class C +type Cap = {*} C + +object Generic: + + class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 + +object Monomorphic: + + class Pair(val x: {*} Cap => Unit, val y: {*} Cap => Unit): + def fst = x + def snd = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 diff --git a/tests/pos-custom-args/captures/try.scala b/tests/pos-custom-args/captures/try.scala new file mode 100644 index 000000000000..a50eeabfb3a3 --- /dev/null +++ b/tests/pos-custom-args/captures/try.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R](op: (erased CanThrow[E]) => R)(handler: E => R): R = + erased val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +val _ = handle { (erased x) => + if true then + raise(new Exception)(using x) + 22 + else + 11 + } \ No newline at end of file diff --git a/tests/pos-custom-args/captures/try3.scala b/tests/pos-custom-args/captures/try3.scala new file mode 100644 index 000000000000..074517d8a9e5 --- /dev/null +++ b/tests/pos-custom-args/captures/try3.scala @@ -0,0 +1,51 @@ +import language.experimental.erasedDefinitions +import annotation.ability +import java.io.IOException + +class CT[-E] // variance is needed for correct rechecking inference +type CanThrow[E] = {*} CT[E] + +def handle[E <: Exception, T](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +def test1: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { + if !a then raise(IOException()) + (b: Boolean) => (_: CanThrow[IOException]) ?=> + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => (_: CanThrow[IOException]) ?=> -1 + } + handle { + val g = f(true) + g(false) // can raise an exception + f(true)(false) // can raise an exception + } { + ex => -1 + } +/* +def test2: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + handle { + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception + } { + ex => -1 + } +*/ \ No newline at end of file