Skip to content

Decode variables #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import tastyquery.jdk.ClasspathLoaders
import java.nio.file.Path
import scala.util.matching.Regex
import tastyquery.Exceptions.NonMethodReferenceException
import tastyquery.SourceLanguage

object BinaryDecoder:
def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder =
Expand Down Expand Up @@ -151,7 +152,7 @@ class BinaryDecoder(using Context, ThrowOrWarn):
}
yield DecodedField.Capture(decodedClass, sym)
case Patterns.Capture(names) =>
decodedClass.symbolOpt.toSeq
decodedClass.treeOpt.toSeq
.flatMap(CaptureCollector.collectCaptures)
.filter { captureSym =>
names.exists {
Expand All @@ -177,6 +178,42 @@ class BinaryDecoder(using Context, ThrowOrWarn):
decodedFields.singleOrThrow(field)
end decode

def decode(variable: binary.Variable, sourceLine: Int): DecodedVariable =
val decodedMethod = decode(variable.declaringMethod)
decode(decodedMethod, variable, sourceLine)

def decode(decodedMethod: DecodedMethod, variable: binary.Variable, sourceLine: Int): DecodedVariable =
def tryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] =
f.applyOrElse(variable, _ => Seq.empty[DecodedVariable])

extension (xs: Seq[DecodedVariable])
def orTryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] =
if xs.nonEmpty then xs else f.applyOrElse(variable, _ => Seq.empty[DecodedVariable])
val decodedVariables = tryDecode {
case Patterns.CapturedLzyVariable(name) => decodeCapturedLzyVariable(decodedMethod, name)
case Patterns.CapturedTailLocalVariable(name) => decodeCapturedVariable(decodedMethod, name)
case Patterns.CapturedVariable(name) => decodeCapturedVariable(decodedMethod, name)
case Patterns.This() => decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _))
case Patterns.DollarThis() => decodeDollarThis(decodedMethod)
case Patterns.Proxy(name) => decodeProxy(decodedMethod, name)
case Patterns.InlinedThis() => decodeInlinedThis(decodedMethod, variable)
}.orTryDecode { case _ =>
decodedMethod match
case decodedMethod: DecodedMethod.SAMOrPartialFunctionImpl =>
decodeValDef(decodedMethod, variable, sourceLine)
.orIfEmpty(
decodeSAMOrPartialFun(
decodedMethod,
decodedMethod.implementedSymbol,
variable,
sourceLine
)
)
case _: DecodedMethod.Bridge => ignore(variable, "Bridge method")
case _ => decodeValDef(decodedMethod, variable, sourceLine)
}
decodedVariables.singleOrThrow(variable, decodedMethod)

private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] =
if syms.size > 1 then
val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _)))
Expand Down Expand Up @@ -1072,3 +1109,100 @@ class BinaryDecoder(using Context, ThrowOrWarn):
case owner: ClassSymbol => owner :: enclosingClassOwners(owner)
case owner: TermOrTypeSymbol => enclosingClassOwners(owner)
case owner: PackageSymbol => Nil

private def decodeCapturedLzyVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] =
decodedMethod match
case m: DecodedMethod.LazyInit if m.symbol.nameStr == name =>
Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol))
case m: DecodedMethod.ValOrDefDef if m.symbol.nameStr == name =>
Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol))
case _ => decodeCapturedVariable(decodedMethod, name)

private def decodeCapturedVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] =
for
metTree <- decodedMethod.treeOpt.toSeq
sym <- CaptureCollector.collectCaptures(metTree)
if name == sym.nameStr
yield DecodedVariable.CapturedVariable(decodedMethod, sym)

private def decodeValDef(
decodedMethod: DecodedMethod,
variable: binary.Variable,
sourceLine: Int
): Seq[DecodedVariable] = decodedMethod.symbolOpt match
case Some(owner: TermSymbol) if owner.sourceLanguage == SourceLanguage.Scala2 =>
for
paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten
if variable.name == paramSym.nameStr
yield DecodedVariable.ValDef(decodedMethod, paramSym)
case _ =>
for
tree <- decodedMethod.treeOpt.toSeq
localVar <- VariableCollector.collectVariables(tree).toSeq
if variable.name == localVar.sym.nameStr &&
(decodedMethod.isGenerated || !variable.declaringMethod.sourceLines.exists(x =>
localVar.sourceFile.name.endsWith(x.sourceName)
) || (localVar.startLine <= sourceLine && sourceLine <= localVar.endLine))
yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm)

private def decodeSAMOrPartialFun(
decodedMethod: DecodedMethod,
owner: TermSymbol,
variable: binary.Variable,
sourceLine: Int
): Seq[DecodedVariable] =
if owner.sourceLanguage == SourceLanguage.Scala2 then
val x =
for
paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten
if variable.name == paramSym.nameStr
yield DecodedVariable.ValDef(decodedMethod, paramSym)
x.toSeq
else
for
localVar <- owner.tree.toSeq.flatMap(t => VariableCollector.collectVariables(t))
if variable.name == localVar.sym.nameStr
yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm)

private def unexpandedSymName(sym: Symbol): String =
"(.+)\\$\\w+".r.unapplySeq(sym.nameStr).map(xs => xs(0)).getOrElse(sym.nameStr)

private def decodeProxy(
decodedMethod: DecodedMethod,
name: String
): Seq[DecodedVariable] =
for
metTree <- decodedMethod.treeOpt.toSeq
localVar <- VariableCollector.collectVariables(metTree)
if name == localVar.sym.nameStr
yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm)

private def decodeInlinedThis(
decodedMethod: DecodedMethod,
variable: binary.Variable
): Seq[DecodedVariable] =
val decodedClassSym = variable.`type` match
case cls: binary.ClassType => decode(cls).classSymbol
case _ => None
for
metTree <- decodedMethod.treeOpt.toSeq
decodedClassSym <- decodedClassSym.toSeq
if VariableCollector.collectVariables(metTree, sym = decodedMethod.symbolOpt).exists { localVar =>
localVar.sym == decodedClassSym
}
yield DecodedVariable.This(decodedMethod, decodedClassSym.thisType)

private def decodeDollarThis(
decodedMethod: DecodedMethod
): Seq[DecodedVariable] =
decodedMethod match
case _: DecodedMethod.TraitStaticForwarder =>
decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _))
case _ =>
for
classSym <- decodedMethod.owner.companionClassSymbol.toSeq
if classSym.isSubClass(defn.AnyValClass)
sym <- classSym.declarations.collect {
case sym: TermSymbol if sym.isVal && !sym.isMethod => sym
}
yield DecodedVariable.AnyValThis(decodedMethod, sym)
27 changes: 26 additions & 1 deletion src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ object DecodedMethod:
override def owner: DecodedClass = target.owner
override def declaredType: TypeOrMethodic = target.declaredType
override def symbolOpt: Option[TermSymbol] = target.symbolOpt
override def treeOpt: Option[Tree] = target.treeOpt
override def toString: String = s"AdaptedFun($target)"

final class SAMOrPartialFunctionConstructor(val owner: DecodedClass, val declaredType: Type) extends DecodedMethod:
Expand Down Expand Up @@ -191,4 +192,28 @@ object DecodedField:
override def toString: String = s"Capture($owner, ${symbol.showBasic})"

final class LazyValBitmap(val owner: DecodedClass, val declaredType: Type, val name: String) extends DecodedField:
override def toString: String = s"LazyValBitmap($owner, , ${declaredType.showBasic})"
override def toString: String = s"LazyValBitmap($owner, ${declaredType.showBasic})"

sealed trait DecodedVariable extends DecodedSymbol:
def owner: DecodedMethod
override def symbolOpt: Option[TermSymbol] = None
def declaredType: TypeOrMethodic

object DecodedVariable:
final class ValDef(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"LocalVariable($owner, ${symbol.showBasic})"

final class CapturedVariable(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"VariableCapture($owner, ${symbol.showBasic})"

final class This(val owner: DecodedMethod, val declaredType: Type) extends DecodedVariable:
override def toString: String = s"This($owner, ${declaredType.showBasic})"

final class AnyValThis(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"AnyValThis($owner, ${declaredType.showBasic})"
31 changes: 31 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@ import tastyquery.Types.*
import scala.annotation.tailrec

class StackTraceFormatter(using ThrowOrWarn):
def format(variable: DecodedVariable): String =
val typeAscription = variable.declaredType match
case tpe: Type => ": " + format(tpe)
case tpe => format(tpe)
val test1 = formatOwner(variable)
val test2 = formatName(variable)
formatName(variable) + typeAscription

def formatMethodSignatures(input: String): String = {
def dotArray(input: Array[String]): String = {
if input.length == 1 then input(0)
else input(0) + "." + dotArray(input.tail)
}
val array =
for str <- input.split('.')
yield
if str.contains(": ") then "{" + str + "}"
else str
dotArray(array)
}

def format(field: DecodedField): String =
val typeAscription = field.declaredType match
case tpe: Type => ": " + format(tpe)
Expand Down Expand Up @@ -69,6 +90,9 @@ class StackTraceFormatter(using ThrowOrWarn):
private def formatOwner(field: DecodedField): String =
format(field.owner)

private def formatOwner(variable: DecodedVariable): String =
format(variable.owner)

private def formatName(field: DecodedField): String =
field match
case field: DecodedField.ValDef => formatName(field.symbol)
Expand All @@ -79,6 +103,13 @@ class StackTraceFormatter(using ThrowOrWarn):
case field: DecodedField.Capture => formatName(field.symbol).dot("<capture>")
case field: DecodedField.LazyValBitmap => field.name.dot("<lazy val bitmap>")

private def formatName(variable: DecodedVariable): String =
variable match
case variable: DecodedVariable.ValDef => formatName(variable.symbol)
case variable: DecodedVariable.CapturedVariable => formatName(variable.symbol).dot("<capture>")
case variable: DecodedVariable.This => "this"
case variable: DecodedVariable.AnyValThis => formatName(variable.symbol)

private def formatName(method: DecodedMethod): String =
method match
case method: DecodedMethod.ValOrDefDef => formatName(method.symbol)
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/ch/epfl/scala/decoder/binary/Method.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.binary
trait Method extends Symbol:
def declaringClass: ClassType
def allParameters: Seq[Parameter]
def variables: Seq[Variable]
// return None if the class of the return type is not yet loaded
def returnType: Option[Type]
def returnTypeName: String
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.binary
trait Symbol:
def name: String
def sourceLines: Option[SourceLines]
def sourceName: Option[String] = sourceLines.map(_.sourceName)

protected def showSpan: String =
def showSpan: String =
sourceLines.map(_.showSpan).getOrElse("")
5 changes: 5 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package ch.epfl.scala.decoder.binary

trait Variable extends Symbol:
def `type`: Type
def declaringMethod: Method
6 changes: 4 additions & 2 deletions src/main/scala/ch/epfl/scala/decoder/exceptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import ch.epfl.scala.decoder.binary
case class AmbiguousException(symbol: binary.Symbol, candidates: Seq[DecodedSymbol])
extends Exception(s"Found ${candidates.size} matching symbols for ${symbol.name}")

case class NotFoundException(symbol: binary.Symbol) extends Exception(s"Cannot find binary symbol of $symbol")
case class NotFoundException(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol])
extends Exception(s"Cannot find binary symbol of $symbol")

case class IgnoredException(symbol: binary.Symbol, reason: String)
extends Exception(s"Ignored $symbol because: $reason")

case class UnexpectedException(message: String) extends Exception(message)

def notFound(symbol: binary.Symbol): Nothing = throw NotFoundException(symbol)
def notFound(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol] = None): Nothing =
throw NotFoundException(symbol, decodedOwner)

def ambiguous(symbol: binary.Symbol, candidates: Seq[DecodedSymbol]): Nothing =
throw AmbiguousException(symbol, candidates)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ import ch.epfl.scala.decoder.ThrowOrWarn
import scala.languageFeature.postfixOps

object CaptureCollector:
def collectCaptures(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn): Set[TermSymbol] =
val collector = CaptureCollector(cls)
collector.traverse(cls.tree)
def collectCaptures(tree: Tree)(using Context, ThrowOrWarn): Set[TermSymbol] =
val collector = CaptureCollector()
collector.traverse(tree)
collector.capture.toSet

class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn) extends TreeTraverser:
val capture: mutable.Set[TermSymbol] = mutable.Set.empty
val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty
class CaptureCollector(using Context, ThrowOrWarn) extends TreeTraverser:
private val capture: mutable.Set[TermSymbol] = mutable.Set.empty
private val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty

def loopCollect(symbol: Symbol)(collect: => Unit): Unit =
if !alreadySeen.contains(symbol) then
Expand All @@ -28,20 +28,16 @@ class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn
override def traverse(tree: Tree): Unit =
tree match
case _: TypeTree => ()
case valDef: ValDef =>
alreadySeen += valDef.symbol
traverse(valDef.rhs)
case bind: Bind =>
alreadySeen += bind.symbol
traverse(bind.body)
case ident: Ident =>
for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do
// check that sym is local
// and check that no owners of sym is cls
if !alreadySeen.contains(sym) then
if sym.isLocal then
if !ownersIsCls(sym) then capture += sym
if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse))
else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse))
if !alreadySeen.contains(sym) && sym.isLocal then
if !sym.isMethod then capture += sym
if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse))
else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse))
case _ => super.traverse(tree)

def ownersIsCls(sym: Symbol): Boolean =
sym.owner match
case owner: Symbol =>
if owner == cls then true
else ownersIsCls(owner)
case null => false
29 changes: 29 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,35 @@ object Patterns:
def unapply(field: binary.Field): Option[String] =
"(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0))

object CapturedLzyVariable:
def unapply(variable: binary.Variable): Option[String] =
"(.+)\\$lzy1\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0))

object CapturedVariable:
def unapply(variable: binary.Variable): Option[String] =
"(.+)\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0))

object CapturedTailLocalVariable:
def unapply(variable: binary.Variable): Option[String] =
"(.+)\\$tailLocal\\d+(\\$\\d+)?".r.unapplySeq(variable.name).map(xs => xs(0))

object This:
def unapply(variable: binary.Variable): Boolean = variable.name == "this"

object DollarThis:
def unapply(variable: binary.Variable): Boolean = variable.name == "$this"

object LazyValVariable:
def unapply(variable: binary.Variable): Option[String] =
"(.+)\\$\\d+\\$lzyVal".r.unapplySeq(variable.name).map(xs => xs(0))

object Proxy:
def unapply(variable: binary.Variable): Option[String] =
"(.+)\\$proxy\\d+".r.unapplySeq(variable.name).map(xs => xs(0))

object InlinedThis:
def unapply(variable: binary.Variable): Boolean = variable.name.endsWith("_this")

extension (field: binary.Field)
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
val extracted = field.unexpandedDecodedNames
Expand Down
Loading