Skip to content

Introduce boundary/break control abstraction. #16612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
94111e9
Introduce `boundary/break` control abstraction.
odersky Jan 2, 2023
a7dc51f
Make `boundary.apply` a regular inline
odersky Jan 2, 2023
5aa4b90
Fix parent of ControlException
odersky Jan 4, 2023
01ac5e0
Deprecate `NonLocalReturns` API
odersky Jan 4, 2023
d1d5301
Add loop test with exit and continue
odersky Jan 10, 2023
4610201
Make `break` non-inlined functions
odersky Jan 11, 2023
04e7bc1
Suppress returns that cross different stack sizes.
odersky Jan 11, 2023
5620e3e
Add test from errorhandling strawman
odersky Jan 11, 2023
395f2b6
Make test Scala.js-compliant
odersky Dec 26, 2022
5447a49
Another test
odersky Dec 28, 2022
3e5a678
Add validation to the strawman
odersky Dec 31, 2022
39e63f3
Adress some review comments and other tweaks
odersky Jan 12, 2023
8fb3c38
Make Label contravariant
odersky Jan 12, 2023
fbc8077
Make Label a final class
odersky Jan 12, 2023
a51a586
Add loops test with separate exit and continue boundaries
odersky Jan 12, 2023
38923f9
Another test
odersky Jan 12, 2023
861f5da
Actually test the stack size problem.
sjrd Jan 12, 2023
63f6e97
Optimize extractors
odersky Jan 12, 2023
da42cf7
Drop ControlException
odersky Jan 12, 2023
2d67e6b
Track stack height differences in the back-end for Labeled Returns.
sjrd Jan 12, 2023
009d623
DropBreaks does not track stack changes anymore.
sjrd Jan 12, 2023
0c1b9b5
Fix test
odersky Jan 12, 2023
fdad63e
Merge pull request #24 from dotty-staging/add-errorhandling-stack-han…
odersky Jan 12, 2023
2e3c4cf
Revert "Suppress returns that cross different stack sizes."
odersky Jan 12, 2023
cf8daf0
Apply suggestions from code review
odersky Jan 16, 2023
211ddb7
Deprecate NonLocalReturns object
odersky Jan 16, 2023
a8f81dd
Fix test
odersky Jan 18, 2023
335d18d
Add boundary break optimization tests
nicolasstucki Jan 10, 2023
6dded2c
Refactoring of break methiod calls
odersky Jan 18, 2023
fcef230
Fix bytecode test
odersky Jan 18, 2023
a17a6df
Drop `break` object in `scala.util`
odersky Jan 20, 2023
91bd5df
Drop `Label_this` and fix comments referring to it
odersky Jan 20, 2023
c3cf035
Make var private
odersky Jan 20, 2023
69b7a48
Update migration warning message
odersky Jan 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 74 additions & 17 deletions compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package jvm

import scala.language.unsafeNulls

import scala.annotation.switch
import scala.annotation.{switch, tailrec}
import scala.collection.mutable.SortedMap

import scala.tools.asm
Expand Down Expand Up @@ -79,9 +79,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {

tree match {
case Assign(lhs @ DesugaredSelect(qual, _), rhs) =>
val savedStackHeight = stackHeight
val isStatic = lhs.symbol.isStaticMember
if (!isStatic) { genLoadQualifier(lhs) }
if (!isStatic) {
genLoadQualifier(lhs)
stackHeight += 1
}
genLoad(rhs, symInfoTK(lhs.symbol))
stackHeight = savedStackHeight
lineNumber(tree)
// receiverClass is used in the bytecode to access the field. using sym.owner may lead to IllegalAccessError
val receiverClass = qual.tpe.typeSymbol
Expand Down Expand Up @@ -145,7 +150,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

genLoad(larg, resKind)
stackHeight += resKind.size
genLoad(rarg, if (isShift) INT else resKind)
stackHeight -= resKind.size

(code: @switch) match {
case ADD => bc add resKind
Expand Down Expand Up @@ -182,14 +189,19 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
if (isArrayGet(code)) {
// load argument on stack
assert(args.length == 1, s"Too many arguments for array get operation: $tree");
stackHeight += 1
genLoad(args.head, INT)
stackHeight -= 1
generatedType = k.asArrayBType.componentType
bc.aload(elementType)
}
else if (isArraySet(code)) {
val List(a1, a2) = args
stackHeight += 1
genLoad(a1, INT)
stackHeight += 1
genLoad(a2)
stackHeight -= 2
generatedType = UNIT
bc.astore(elementType)
} else {
Expand Down Expand Up @@ -223,7 +235,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)

val postIf = new asm.Label
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf))
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stackHeight))
markProgramPoint(failure)
genLoadTo(elsep, resKind, LoadDestination.FallThrough)
markProgramPoint(postIf)
Expand Down Expand Up @@ -482,7 +494,17 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
dest match
case LoadDestination.FallThrough =>
()
case LoadDestination.Jump(label) =>
case LoadDestination.Jump(label, targetStackHeight) =>
if targetStackHeight < stackHeight then
val stackDiff = stackHeight - targetStackHeight
if expectedType == UNIT then
bc dropMany stackDiff
else
val loc = locals.makeTempLocal(expectedType)
bc.store(loc.idx, expectedType)
bc dropMany stackDiff
bc.load(loc.idx, expectedType)
end if
bc goTo label
case LoadDestination.Return =>
bc emitRETURN returnType
Expand Down Expand Up @@ -577,7 +599,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
if dest == LoadDestination.FallThrough then
val resKind = tpeTK(tree)
val jumpTarget = new asm.Label
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget))
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stackHeight))
genLoad(expr, resKind)
markProgramPoint(jumpTarget)
resKind
Expand Down Expand Up @@ -635,7 +657,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
markProgramPoint(loop)

if isInfinite then
val dest = LoadDestination.Jump(loop)
val dest = LoadDestination.Jump(loop, stackHeight)
genLoadTo(body, UNIT, dest)
dest
else
Expand All @@ -650,7 +672,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val failure = new asm.Label
genCond(cond, success, failure, targetIfNoJump = success)
markProgramPoint(success)
genLoadTo(body, UNIT, LoadDestination.Jump(loop))
genLoadTo(body, UNIT, LoadDestination.Jump(loop, stackHeight))
markProgramPoint(failure)
end match
LoadDestination.FallThrough
Expand Down Expand Up @@ -744,7 +766,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {

// scala/bug#10290: qual can be `this.$outer()` (not just `this`), so we call genLoad (not just ALOAD_0)
genLoad(superQual)
stackHeight += 1
genLoadArguments(args, paramTKs(app))
stackHeight -= 1
generatedType = genCallMethod(fun.symbol, InvokeStyle.Super, app.span)

// 'new' constructor call: Note: since constructors are
Expand All @@ -766,7 +790,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
assert(classBTypeFromSymbol(ctor.owner) == rt, s"Symbol ${ctor.owner.showFullName} is different from $rt")
mnode.visitTypeInsn(asm.Opcodes.NEW, rt.internalName)
bc dup generatedType
stackHeight += 2
genLoadArguments(args, paramTKs(app))
stackHeight -= 2
genCallMethod(ctor, InvokeStyle.Special, app.span)

case _ =>
Expand Down Expand Up @@ -799,8 +825,12 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
else if (app.hasAttachment(BCodeHelpers.UseInvokeSpecial)) InvokeStyle.Special
else InvokeStyle.Virtual

if (invokeStyle.hasInstance) genLoadQualifier(fun)
val savedStackHeight = stackHeight
if invokeStyle.hasInstance then
genLoadQualifier(fun)
stackHeight += 1
genLoadArguments(args, paramTKs(app))
stackHeight = savedStackHeight

val DesugaredSelect(qual, name) = fun: @unchecked // fun is a Select, also checked in genLoadQualifier
val isArrayClone = name == nme.clone_ && qual.tpe.widen.isInstanceOf[JavaArrayType]
Expand Down Expand Up @@ -858,6 +888,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
bc iconst elems.length
bc newarray elmKind

stackHeight += 3 // during the genLoad below, there is the result, its dup, and the index

var i = 0
var rest = elems
while (!rest.isEmpty) {
Expand All @@ -869,6 +901,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
i = i + 1
}

stackHeight -= 3

generatedType
}

Expand All @@ -883,7 +917,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val (generatedType, postMatch, postMatchDest) =
if dest == LoadDestination.FallThrough then
val postMatch = new asm.Label
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch))
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stackHeight))
else
(expectedType, null, dest)

Expand Down Expand Up @@ -1160,14 +1194,21 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

def genLoadArguments(args: List[Tree], btpes: List[BType]): Unit =
args match
case arg :: args1 =>
btpes match
case btpe :: btpes1 =>
genLoad(arg, btpe)
genLoadArguments(args1, btpes1)
case _ =>
case _ =>
@tailrec def loop(args: List[Tree], btpes: List[BType]): Unit =
args match
case arg :: args1 =>
btpes match
case btpe :: btpes1 =>
genLoad(arg, btpe)
stackHeight += btpe.size
loop(args1, btpes1)
case _ =>
case _ =>

val savedStackHeight = stackHeight
loop(args, btpes)
stackHeight = savedStackHeight
end genLoadArguments

def genLoadModule(tree: Tree): BType = {
val module = (
Expand Down Expand Up @@ -1266,11 +1307,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}.sum
bc.genNewStringBuilder(approxBuilderSize)

stackHeight += 1 // during the genLoad below, there is a reference to the StringBuilder on the stack
for (elem <- concatArguments) {
val elemType = tpeTK(elem)
genLoad(elem, elemType)
bc.genStringBuilderAppend(elemType)
}
stackHeight -= 1

bc.genStringBuilderEnd
} else {

Expand All @@ -1287,12 +1331,15 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
var totalArgSlots = 0
var countConcats = 1 // ie. 1 + how many times we spilled

val savedStackHeight = stackHeight

for (elem <- concatArguments) {
val tpe = tpeTK(elem)
val elemSlots = tpe.size

// Unlikely spill case
if (totalArgSlots + elemSlots >= MaxIndySlots) {
stackHeight = savedStackHeight + countConcats
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())
countConcats += 1
totalArgSlots = 0
Expand All @@ -1317,8 +1364,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val tpe = tpeTK(elem)
argTypes += tpe.toASMType
genLoad(elem, tpe)
stackHeight += 1
}
}
stackHeight = savedStackHeight
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())

// If we spilled, generate one final concat
Expand Down Expand Up @@ -1513,7 +1562,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
} else {
val tk = tpeTK(l).maxType(tpeTK(r))
genLoad(l, tk)
stackHeight += tk.size
genLoad(r, tk)
stackHeight -= tk.size
genCJUMP(success, failure, op, tk, targetIfNoJump)
}
}
Expand Down Expand Up @@ -1628,7 +1679,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

genLoad(l, ObjectRef)
stackHeight += 1
genLoad(r, ObjectRef)
stackHeight -= 1
genCallMethod(equalsMethod, InvokeStyle.Static)
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
}
Expand All @@ -1644,7 +1697,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
} else if (isNonNullExpr(l)) {
// SI-7852 Avoid null check if L is statically non-null.
genLoad(l, ObjectRef)
stackHeight += 1
genLoad(r, ObjectRef)
stackHeight -= 1
genCallMethod(defn.Any_equals, InvokeStyle.Virtual)
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
} else {
Expand All @@ -1654,7 +1709,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val lNonNull = new asm.Label

genLoad(l, ObjectRef)
stackHeight += 1
genLoad(r, ObjectRef)
stackHeight -= 1
locals.store(eqEqTempLocal)
bc dup ObjectRef
genCZJUMP(lNull, lNonNull, Primitives.EQ, ObjectRef, targetIfNoJump = lNull)
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,16 @@ trait BCodeIdiomatic {
// can-multi-thread
final def drop(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.POP2 else Opcodes.POP) }

// can-multi-thread
final def dropMany(size: Int): Unit = {
var s = size
while s >= 2 do
emit(Opcodes.POP2)
s -= 2
if s > 0 then
emit(Opcodes.POP)
}

// can-multi-thread
final def dup(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.DUP2 else Opcodes.DUP) }

Expand Down
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ trait BCodeSkelBuilder extends BCodeHelpers {
/** The value is put on the stack, and control flows through to the next opcode. */
case FallThrough
/** The value is put on the stack, and control flow is transferred to the given `label`. */
case Jump(label: asm.Label)
case Jump(label: asm.Label, targetStackHeight: Int)
/** The value is RETURN'ed from the enclosing method. */
case Return
/** The value is ATHROW'n. */
Expand Down Expand Up @@ -368,6 +368,8 @@ trait BCodeSkelBuilder extends BCodeHelpers {
// used by genLoadTry() and genSynchronized()
var earlyReturnVar: Symbol = null
var shouldEmitCleanup = false
// stack tracking
var stackHeight = 0
// line numbers
var lastEmittedLineNr = -1

Expand Down Expand Up @@ -504,6 +506,13 @@ trait BCodeSkelBuilder extends BCodeHelpers {
loc
}

def makeTempLocal(tk: BType): Local =
assert(nxtIdx != -1, "not a valid start index")
assert(tk.size > 0, "makeLocal called for a symbol whose type is Unit.")
val loc = Local(tk, "temp", nxtIdx, isSynth = true)
nxtIdx += tk.size
loc

// not to be confused with `fieldStore` and `fieldLoad` which also take a symbol but a field-symbol.
def store(locSym: Symbol): Unit = {
val Local(tk, _, idx, _) = slots(locSym)
Expand Down Expand Up @@ -574,6 +583,8 @@ trait BCodeSkelBuilder extends BCodeHelpers {
earlyReturnVar = null
shouldEmitCleanup = false

stackHeight = 0

lastEmittedLineNr = -1
}

Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class Compiler {
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
new ExplicitOuter, // Add accessors to outer classes from nested ones.
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
new DropBreaks) :: // Optimize local Break throws by rewriting them
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
new InlinePatterns, // Remove placeholders of inlined patterns
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ => tree
}

def stripTyped(tree: Tree): Tree = unsplice(tree) match
case Typed(expr, _) =>
stripTyped(expr)
case _ =>
tree

/** The number of arguments in an application */
def numArgs(tree: Tree): Int = unsplice(tree) match {
case Apply(fn, args) => numArgs(fn) + args.length
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,10 @@ class Definitions {
def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass
def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions")

@tu lazy val boundaryModule: Symbol = requiredModule("scala.util.boundary")
@tu lazy val LabelClass: Symbol = requiredClass("scala.util.boundary.Label")
@tu lazy val BreakClass: Symbol = requiredClass("scala.util.boundary.Break")

@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("*")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NameKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ object NameKinds {

val LocalOptInlineLocalObj: UniqueNameKind = new UniqueNameKind("ilo")

val BoundaryName: UniqueNameKind = new UniqueNameKind("boundary")

/** The kind of names of default argument getters */
val DefaultGetterName: NumberedNameKind = new NumberedNameKind(DEFAULTGETTER, "DefaultGetter") {
def mkString(underlying: TermName, info: ThisInfo) = {
Expand Down
Loading