diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala index 972a3ba..3b1eeae 100644 --- a/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala @@ -15,6 +15,7 @@ import tastyquery.jdk.ClasspathLoaders import java.nio.file.Path import scala.util.matching.Regex import tastyquery.Exceptions.NonMethodReferenceException +import tastyquery.SourceLanguage object BinaryDecoder: def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder = @@ -151,7 +152,7 @@ class BinaryDecoder(using Context, ThrowOrWarn): } yield DecodedField.Capture(decodedClass, sym) case Patterns.Capture(names) => - decodedClass.symbolOpt.toSeq + decodedClass.treeOpt.toSeq .flatMap(CaptureCollector.collectCaptures) .filter { captureSym => names.exists { @@ -177,6 +178,42 @@ class BinaryDecoder(using Context, ThrowOrWarn): decodedFields.singleOrThrow(field) end decode + def decode(variable: binary.Variable, sourceLine: Int): DecodedVariable = + val decodedMethod = decode(variable.declaringMethod) + decode(decodedMethod, variable, sourceLine) + + def decode(decodedMethod: DecodedMethod, variable: binary.Variable, sourceLine: Int): DecodedVariable = + def tryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = + f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) + + extension (xs: Seq[DecodedVariable]) + def orTryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = + if xs.nonEmpty then xs else f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) + val decodedVariables = tryDecode { + case Patterns.CapturedLzyVariable(name) => decodeCapturedLzyVariable(decodedMethod, name) + case Patterns.CapturedTailLocalVariable(name) => decodeCapturedVariable(decodedMethod, name) + case Patterns.CapturedVariable(name) => decodeCapturedVariable(decodedMethod, name) + case Patterns.This() => decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) + case Patterns.DollarThis() => decodeDollarThis(decodedMethod) + case Patterns.Proxy(name) => decodeProxy(decodedMethod, name) + case Patterns.InlinedThis() => decodeInlinedThis(decodedMethod, variable) + }.orTryDecode { case _ => + decodedMethod match + case decodedMethod: DecodedMethod.SAMOrPartialFunctionImpl => + decodeValDef(decodedMethod, variable, sourceLine) + .orIfEmpty( + decodeSAMOrPartialFun( + decodedMethod, + decodedMethod.implementedSymbol, + variable, + sourceLine + ) + ) + case _: DecodedMethod.Bridge => ignore(variable, "Bridge method") + case _ => decodeValDef(decodedMethod, variable, sourceLine) + } + decodedVariables.singleOrThrow(variable, decodedMethod) + private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] = if syms.size > 1 then val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _))) @@ -1072,3 +1109,100 @@ class BinaryDecoder(using Context, ThrowOrWarn): case owner: ClassSymbol => owner :: enclosingClassOwners(owner) case owner: TermOrTypeSymbol => enclosingClassOwners(owner) case owner: PackageSymbol => Nil + + private def decodeCapturedLzyVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] = + decodedMethod match + case m: DecodedMethod.LazyInit if m.symbol.nameStr == name => + Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) + case m: DecodedMethod.ValOrDefDef if m.symbol.nameStr == name => + Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) + case _ => decodeCapturedVariable(decodedMethod, name) + + private def decodeCapturedVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] = + for + metTree <- decodedMethod.treeOpt.toSeq + sym <- CaptureCollector.collectCaptures(metTree) + if name == sym.nameStr + yield DecodedVariable.CapturedVariable(decodedMethod, sym) + + private def decodeValDef( + decodedMethod: DecodedMethod, + variable: binary.Variable, + sourceLine: Int + ): Seq[DecodedVariable] = decodedMethod.symbolOpt match + case Some(owner: TermSymbol) if owner.sourceLanguage == SourceLanguage.Scala2 => + for + paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten + if variable.name == paramSym.nameStr + yield DecodedVariable.ValDef(decodedMethod, paramSym) + case _ => + for + tree <- decodedMethod.treeOpt.toSeq + localVar <- VariableCollector.collectVariables(tree).toSeq + if variable.name == localVar.sym.nameStr && + (decodedMethod.isGenerated || !variable.declaringMethod.sourceLines.exists(x => + localVar.sourceFile.name.endsWith(x.sourceName) + ) || (localVar.startLine <= sourceLine && sourceLine <= localVar.endLine)) + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeSAMOrPartialFun( + decodedMethod: DecodedMethod, + owner: TermSymbol, + variable: binary.Variable, + sourceLine: Int + ): Seq[DecodedVariable] = + if owner.sourceLanguage == SourceLanguage.Scala2 then + val x = + for + paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten + if variable.name == paramSym.nameStr + yield DecodedVariable.ValDef(decodedMethod, paramSym) + x.toSeq + else + for + localVar <- owner.tree.toSeq.flatMap(t => VariableCollector.collectVariables(t)) + if variable.name == localVar.sym.nameStr + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def unexpandedSymName(sym: Symbol): String = + "(.+)\\$\\w+".r.unapplySeq(sym.nameStr).map(xs => xs(0)).getOrElse(sym.nameStr) + + private def decodeProxy( + decodedMethod: DecodedMethod, + name: String + ): Seq[DecodedVariable] = + for + metTree <- decodedMethod.treeOpt.toSeq + localVar <- VariableCollector.collectVariables(metTree) + if name == localVar.sym.nameStr + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeInlinedThis( + decodedMethod: DecodedMethod, + variable: binary.Variable + ): Seq[DecodedVariable] = + val decodedClassSym = variable.`type` match + case cls: binary.ClassType => decode(cls).classSymbol + case _ => None + for + metTree <- decodedMethod.treeOpt.toSeq + decodedClassSym <- decodedClassSym.toSeq + if VariableCollector.collectVariables(metTree, sym = decodedMethod.symbolOpt).exists { localVar => + localVar.sym == decodedClassSym + } + yield DecodedVariable.This(decodedMethod, decodedClassSym.thisType) + + private def decodeDollarThis( + decodedMethod: DecodedMethod + ): Seq[DecodedVariable] = + decodedMethod match + case _: DecodedMethod.TraitStaticForwarder => + decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) + case _ => + for + classSym <- decodedMethod.owner.companionClassSymbol.toSeq + if classSym.isSubClass(defn.AnyValClass) + sym <- classSym.declarations.collect { + case sym: TermSymbol if sym.isVal && !sym.isMethod => sym + } + yield DecodedVariable.AnyValThis(decodedMethod, sym) diff --git a/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala b/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala index 6d1a52e..2218188 100644 --- a/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala +++ b/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala @@ -146,6 +146,7 @@ object DecodedMethod: override def owner: DecodedClass = target.owner override def declaredType: TypeOrMethodic = target.declaredType override def symbolOpt: Option[TermSymbol] = target.symbolOpt + override def treeOpt: Option[Tree] = target.treeOpt override def toString: String = s"AdaptedFun($target)" final class SAMOrPartialFunctionConstructor(val owner: DecodedClass, val declaredType: Type) extends DecodedMethod: @@ -191,4 +192,28 @@ object DecodedField: override def toString: String = s"Capture($owner, ${symbol.showBasic})" final class LazyValBitmap(val owner: DecodedClass, val declaredType: Type, val name: String) extends DecodedField: - override def toString: String = s"LazyValBitmap($owner, , ${declaredType.showBasic})" + override def toString: String = s"LazyValBitmap($owner, ${declaredType.showBasic})" + +sealed trait DecodedVariable extends DecodedSymbol: + def owner: DecodedMethod + override def symbolOpt: Option[TermSymbol] = None + def declaredType: TypeOrMethodic + +object DecodedVariable: + final class ValDef(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: + def declaredType: TypeOrMethodic = symbol.declaredType + override def symbolOpt: Option[TermSymbol] = Some(symbol) + override def toString: String = s"LocalVariable($owner, ${symbol.showBasic})" + + final class CapturedVariable(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: + def declaredType: TypeOrMethodic = symbol.declaredType + override def symbolOpt: Option[TermSymbol] = Some(symbol) + override def toString: String = s"VariableCapture($owner, ${symbol.showBasic})" + + final class This(val owner: DecodedMethod, val declaredType: Type) extends DecodedVariable: + override def toString: String = s"This($owner, ${declaredType.showBasic})" + + final class AnyValThis(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: + def declaredType: TypeOrMethodic = symbol.declaredType + override def symbolOpt: Option[TermSymbol] = Some(symbol) + override def toString: String = s"AnyValThis($owner, ${declaredType.showBasic})" diff --git a/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala b/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala index 55d3dcf..02aba7b 100644 --- a/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala +++ b/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala @@ -10,6 +10,27 @@ import tastyquery.Types.* import scala.annotation.tailrec class StackTraceFormatter(using ThrowOrWarn): + def format(variable: DecodedVariable): String = + val typeAscription = variable.declaredType match + case tpe: Type => ": " + format(tpe) + case tpe => format(tpe) + val test1 = formatOwner(variable) + val test2 = formatName(variable) + formatName(variable) + typeAscription + + def formatMethodSignatures(input: String): String = { + def dotArray(input: Array[String]): String = { + if input.length == 1 then input(0) + else input(0) + "." + dotArray(input.tail) + } + val array = + for str <- input.split('.') + yield + if str.contains(": ") then "{" + str + "}" + else str + dotArray(array) + } + def format(field: DecodedField): String = val typeAscription = field.declaredType match case tpe: Type => ": " + format(tpe) @@ -69,6 +90,9 @@ class StackTraceFormatter(using ThrowOrWarn): private def formatOwner(field: DecodedField): String = format(field.owner) + private def formatOwner(variable: DecodedVariable): String = + format(variable.owner) + private def formatName(field: DecodedField): String = field match case field: DecodedField.ValDef => formatName(field.symbol) @@ -79,6 +103,13 @@ class StackTraceFormatter(using ThrowOrWarn): case field: DecodedField.Capture => formatName(field.symbol).dot("") case field: DecodedField.LazyValBitmap => field.name.dot("") + private def formatName(variable: DecodedVariable): String = + variable match + case variable: DecodedVariable.ValDef => formatName(variable.symbol) + case variable: DecodedVariable.CapturedVariable => formatName(variable.symbol).dot("") + case variable: DecodedVariable.This => "this" + case variable: DecodedVariable.AnyValThis => formatName(variable.symbol) + private def formatName(method: DecodedMethod): String = method match case method: DecodedMethod.ValOrDefDef => formatName(method.symbol) diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala index c5336f5..196583c 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.binary trait Method extends Symbol: def declaringClass: ClassType def allParameters: Seq[Parameter] + def variables: Seq[Variable] // return None if the class of the return type is not yet loaded def returnType: Option[Type] def returnTypeName: String diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala index bfe2976..83b8366 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.binary trait Symbol: def name: String def sourceLines: Option[SourceLines] + def sourceName: Option[String] = sourceLines.map(_.sourceName) - protected def showSpan: String = + def showSpan: String = sourceLines.map(_.showSpan).getOrElse("") diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala new file mode 100644 index 0000000..28310c0 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala @@ -0,0 +1,5 @@ +package ch.epfl.scala.decoder.binary + +trait Variable extends Symbol: + def `type`: Type + def declaringMethod: Method diff --git a/src/main/scala/ch/epfl/scala/decoder/exceptions.scala b/src/main/scala/ch/epfl/scala/decoder/exceptions.scala index 4c0ae1b..0653c65 100644 --- a/src/main/scala/ch/epfl/scala/decoder/exceptions.scala +++ b/src/main/scala/ch/epfl/scala/decoder/exceptions.scala @@ -5,14 +5,16 @@ import ch.epfl.scala.decoder.binary case class AmbiguousException(symbol: binary.Symbol, candidates: Seq[DecodedSymbol]) extends Exception(s"Found ${candidates.size} matching symbols for ${symbol.name}") -case class NotFoundException(symbol: binary.Symbol) extends Exception(s"Cannot find binary symbol of $symbol") +case class NotFoundException(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol]) + extends Exception(s"Cannot find binary symbol of $symbol") case class IgnoredException(symbol: binary.Symbol, reason: String) extends Exception(s"Ignored $symbol because: $reason") case class UnexpectedException(message: String) extends Exception(message) -def notFound(symbol: binary.Symbol): Nothing = throw NotFoundException(symbol) +def notFound(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol] = None): Nothing = + throw NotFoundException(symbol, decodedOwner) def ambiguous(symbol: binary.Symbol, candidates: Seq[DecodedSymbol]): Nothing = throw AmbiguousException(symbol, candidates) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala b/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala index 7994e79..c2f6c6e 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala @@ -12,14 +12,14 @@ import ch.epfl.scala.decoder.ThrowOrWarn import scala.languageFeature.postfixOps object CaptureCollector: - def collectCaptures(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn): Set[TermSymbol] = - val collector = CaptureCollector(cls) - collector.traverse(cls.tree) + def collectCaptures(tree: Tree)(using Context, ThrowOrWarn): Set[TermSymbol] = + val collector = CaptureCollector() + collector.traverse(tree) collector.capture.toSet -class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn) extends TreeTraverser: - val capture: mutable.Set[TermSymbol] = mutable.Set.empty - val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty +class CaptureCollector(using Context, ThrowOrWarn) extends TreeTraverser: + private val capture: mutable.Set[TermSymbol] = mutable.Set.empty + private val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty def loopCollect(symbol: Symbol)(collect: => Unit): Unit = if !alreadySeen.contains(symbol) then @@ -28,20 +28,16 @@ class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn override def traverse(tree: Tree): Unit = tree match case _: TypeTree => () + case valDef: ValDef => + alreadySeen += valDef.symbol + traverse(valDef.rhs) + case bind: Bind => + alreadySeen += bind.symbol + traverse(bind.body) case ident: Ident => for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do - // check that sym is local - // and check that no owners of sym is cls - if !alreadySeen.contains(sym) then - if sym.isLocal then - if !ownersIsCls(sym) then capture += sym - if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse)) - else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse)) + if !alreadySeen.contains(sym) && sym.isLocal then + if !sym.isMethod then capture += sym + if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse)) + else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse)) case _ => super.traverse(tree) - - def ownersIsCls(sym: Symbol): Boolean = - sym.owner match - case owner: Symbol => - if owner == cls then true - else ownersIsCls(owner) - case null => false diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala b/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala index 6db851b..3ee46bf 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala @@ -166,6 +166,35 @@ object Patterns: def unapply(field: binary.Field): Option[String] = "(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0)) + object CapturedLzyVariable: + def unapply(variable: binary.Variable): Option[String] = + "(.+)\\$lzy1\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0)) + + object CapturedVariable: + def unapply(variable: binary.Variable): Option[String] = + "(.+)\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0)) + + object CapturedTailLocalVariable: + def unapply(variable: binary.Variable): Option[String] = + "(.+)\\$tailLocal\\d+(\\$\\d+)?".r.unapplySeq(variable.name).map(xs => xs(0)) + + object This: + def unapply(variable: binary.Variable): Boolean = variable.name == "this" + + object DollarThis: + def unapply(variable: binary.Variable): Boolean = variable.name == "$this" + + object LazyValVariable: + def unapply(variable: binary.Variable): Option[String] = + "(.+)\\$\\d+\\$lzyVal".r.unapplySeq(variable.name).map(xs => xs(0)) + + object Proxy: + def unapply(variable: binary.Variable): Option[String] = + "(.+)\\$proxy\\d+".r.unapplySeq(variable.name).map(xs => xs(0)) + + object InlinedThis: + def unapply(variable: binary.Variable): Boolean = variable.name.endsWith("_this") + extension (field: binary.Field) private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] = val extracted = field.unexpandedDecodedNames diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala b/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala new file mode 100644 index 0000000..859c439 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala @@ -0,0 +1,104 @@ +package ch.epfl.scala.decoder.internal + +import tastyquery.Trees.* +import scala.collection.mutable +import tastyquery.Symbols.* +import tastyquery.Traversers.* +import tastyquery.Contexts.* +import tastyquery.SourcePosition +import tastyquery.Types.* +import tastyquery.Traversers +import ch.epfl.scala.decoder.ThrowOrWarn +import scala.languageFeature.postfixOps +import tastyquery.SourceFile + +object VariableCollector: + def collectVariables(tree: Tree, sym: Option[TermSymbol] = None)(using Context, ThrowOrWarn): Set[LocalVariable] = + val collector = VariableCollector() + collector.collect(tree, sym) + +trait LocalVariable: + def sym: Symbol + def sourceFile: SourceFile + def startLine: Int + def endLine: Int + +object LocalVariable: + case class This(sym: ClassSymbol, sourceFile: SourceFile, startLine: Int, endLine: Int) extends LocalVariable + case class ValDef(sym: TermSymbol, sourceFile: SourceFile, startLine: Int, endLine: Int) extends LocalVariable + case class InlinedFromDef(underlying: LocalVariable, inlineCall: InlineCall) extends LocalVariable: + def sym: Symbol = underlying.sym + def startLine: Int = underlying.startLine + def endLine: Int = underlying.endLine + def sourceFile: SourceFile = underlying.sourceFile + +end LocalVariable + +class VariableCollector()(using Context, ThrowOrWarn) extends TreeTraverser: + private val inlinedVariables = mutable.Map.empty[TermSymbol, Set[LocalVariable]] + + def collect(tree: Tree, sym: Option[TermSymbol] = None): Set[LocalVariable] = + val variables: mutable.Set[LocalVariable] = mutable.Set.empty + var previousTree: mutable.Stack[Tree] = mutable.Stack.empty + + object Traverser extends TreeTraverser: + override def traverse(tree: Tree): Unit = + tree match + case valDefOrBind: (ValDef | Bind) => addValDefOrBind(valDefOrBind) + case _ => () + + tree match + case _: TypeTree => () + // case _: DefDef => () + case valDef: ValDef => traverse(valDef.rhs) + case bind: Bind => traverse(bind.body) + case InlineCall(inlineCall) => + val localVariables = + inlinedVariables.getOrElseUpdate(inlineCall.symbol, collectInlineDef(inlineCall.symbol)) + variables ++= localVariables.map(LocalVariable.InlinedFromDef(_, inlineCall)) + previousTree.push(inlineCall.termRefTree) + super.traverse(inlineCall.termRefTree) + previousTree.pop() + case _ => + previousTree.push(tree) + super.traverse(tree) + previousTree.pop() + + private def addValDefOrBind(valDef: ValDef | Bind): Unit = + val sym = valDef.symbol.asInstanceOf[TermSymbol] + previousTree + .collectFirst { case tree: (Block | CaseDef | DefDef) => tree } + .foreach { parentTree => + variables += + LocalVariable.ValDef( + sym, + parentTree.pos.sourceFile, + valDef.pos.startLine + 1, + parentTree.pos.endLine + 1 + ) + } + + sym + .flatMap(allOuterClasses) + .foreach(cls => + variables += LocalVariable.This( + cls, + cls.pos.sourceFile, + if cls.pos.isUnknown then -1 else cls.pos.startLine + 1, + if cls.pos.isUnknown then -1 else cls.pos.endLine + 1 + ) + ) + Traverser.traverse(tree) + variables.toSet + end collect + + private def collectInlineDef(symbol: TermSymbol): Set[LocalVariable] = + inlinedVariables(symbol) = Set.empty // break recursion + symbol.tree.toSet.flatMap(tree => collect(tree, Some(symbol))) + + private def allOuterClasses(sym: Symbol): List[ClassSymbol] = + def loop(sym: Symbol, acc: List[ClassSymbol]): List[ClassSymbol] = + sym.outerClass match + case Some(cls) => loop(cls, cls :: acc) + case None => acc + loop(sym, Nil) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala b/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala index a285581..d638d95 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala @@ -90,8 +90,10 @@ extension [A, S[+X] <: IterableOnce[X]](xs: S[A]) extension [T <: DecodedSymbol](xs: Seq[T]) def singleOrThrow(symbol: binary.Symbol): T = - singleOptOrThrow(symbol) - .getOrElse(notFound(symbol)) + singleOptOrThrow(symbol).getOrElse(notFound(symbol)) + + def singleOrThrow(symbol: binary.Symbol, decodedOwner: DecodedSymbol): T = + singleOptOrThrow(symbol).getOrElse(notFound(symbol, Some(decodedOwner))) def singleOptOrThrow(symbol: binary.Symbol): Option[T] = if xs.size > 1 then ambiguous(symbol, xs) @@ -252,7 +254,10 @@ extension (self: DecodedClass) def linearization(using Context): Seq[ClassSymbol] = classSymbol.toSeq.flatMap(_.linearization) - def thisType(using Context): Option[ThisType] = classSymbol.map(_.thisType) + def thisType(using Context): Option[Type] = self match + case self: DecodedClass.SAMOrPartialFunction => Some(self.tpe) + case inlined: DecodedClass.InlinedClass => inlined.underlying.thisType + case _ => classSymbol.map(_.thisType) def companionClass(using Context): Option[DecodedClass] = self.companionClassSymbol.map(DecodedClass.ClassDef(_)) @@ -324,3 +329,11 @@ extension (field: DecodedField) case field: DecodedField.SerialVersionUID => true case field: DecodedField.Capture => true case field: DecodedField.LazyValBitmap => true + +extension (variable: DecodedVariable) + def isGenerated: Boolean = + variable match + case variable: DecodedVariable.ValDef => false + case variable: DecodedVariable.CapturedVariable => true + case variable: DecodedVariable.This => true + case variable: DecodedVariable.AnyValThis => true diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala new file mode 100644 index 0000000..ffb4fb4 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala @@ -0,0 +1,12 @@ +package ch.epfl.scala.decoder.javareflect + +import ch.epfl.scala.decoder.binary +import ch.epfl.scala.decoder.binary.Type +import ch.epfl.scala.decoder.binary.SourceLines +import ch.epfl.scala.decoder.binary.Method +import ch.epfl.scala.decoder.binary.Variable + +class AsmVariable(val name: String, val `type`: Type, val declaringMethod: Method, val sourceLines: Option[SourceLines]) + extends Variable: + + override def toString: String = s"$name: ${`type`.name}" diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala index ea0726a..3b80b8f 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala @@ -1,8 +1,16 @@ package ch.epfl.scala.decoder.javareflect import ch.epfl.scala.decoder.binary +import org.objectweb.asm -private case class ExtraMethodInfo(sourceLines: Option[binary.SourceLines], instructions: Seq[binary.Instruction]) +private case class ExtraMethodInfo( + sourceLines: Option[binary.SourceLines], + instructions: Seq[binary.Instruction], + variables: Seq[ExtraMethodInfo.Variable], + labels: Map[asm.Label, Int] +) private object ExtraMethodInfo: - def empty: ExtraMethodInfo = ExtraMethodInfo(None, Seq.empty) + def empty: ExtraMethodInfo = ExtraMethodInfo(None, Seq.empty, Seq.empty, Map.empty) + case class Variable(name: String, descriptor: String, signature: String, start: asm.Label, end: asm.Label) + case class LineNumberNode(line: Int, start: asm.Label, sourceName: String) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala index e14d16a..fa569b8 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala @@ -4,7 +4,7 @@ import ch.epfl.scala.decoder.binary import scala.util.matching.Regex import scala.jdk.CollectionConverters.* -class JavaReflectClass(cls: Class[?], extraInfo: ExtraClassInfo, override val classLoader: JavaReflectLoader) +class JavaReflectClass(val cls: Class[?], extraInfo: ExtraClassInfo, override val classLoader: JavaReflectLoader) extends binary.ClassType: override def name: String = cls.getTypeName override def superclass = Option(cls.getSuperclass).map(classLoader.loadClass) @@ -44,3 +44,33 @@ class JavaReflectClass(cls: Class[?], extraInfo: ExtraClassInfo, override val cl override def declaredFields: Seq[binary.Field] = cls.getDeclaredFields().map(f => JavaReflectField(f, classLoader)) + +object JavaReflectClass: + val boolean: JavaReflectClass = JavaReflectClass(classOf[Boolean], ExtraClassInfo.empty, null) + val int: JavaReflectClass = JavaReflectClass(classOf[Int], ExtraClassInfo.empty, null) + val long: JavaReflectClass = JavaReflectClass(classOf[Long], ExtraClassInfo.empty, null) + val float: JavaReflectClass = JavaReflectClass(classOf[Float], ExtraClassInfo.empty, null) + val double: JavaReflectClass = JavaReflectClass(classOf[Double], ExtraClassInfo.empty, null) + val byte: JavaReflectClass = JavaReflectClass(classOf[Byte], ExtraClassInfo.empty, null) + val char: JavaReflectClass = JavaReflectClass(classOf[Char], ExtraClassInfo.empty, null) + val short: JavaReflectClass = JavaReflectClass(classOf[Short], ExtraClassInfo.empty, null) + val void: JavaReflectClass = JavaReflectClass(classOf[Unit], ExtraClassInfo.empty, null) + + val primitives: Map[String, JavaReflectClass] = Map( + "boolean" -> boolean, + "int" -> int, + "long" -> long, + "float" -> float, + "double" -> double, + "byte" -> byte, + "char" -> char, + "short" -> short, + "void" -> void + ) + + def array(componentType: JavaReflectClass): JavaReflectClass = + JavaReflectClass( + java.lang.reflect.Array.newInstance(componentType.cls, 0).getClass, + ExtraClassInfo.empty, + componentType.classLoader + ) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala index db11ebb..e31c534 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala @@ -1,6 +1,7 @@ package ch.epfl.scala.decoder.javareflect import ch.epfl.scala.decoder.binary +import ch.epfl.scala.decoder.binary.Variable import java.lang.reflect.Constructor import java.lang.reflect.Method import ch.epfl.scala.decoder.binary.SignedName @@ -12,6 +13,8 @@ class JavaReflectConstructor( loader: JavaReflectLoader ) extends binary.Method: + override def variables: Seq[Variable] = Seq.empty + override def returnType: Option[binary.Type] = Some(loader.loadClass(classOf[Unit])) override def returnTypeName: String = "void" diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala index 1c8ebc4..c9d1e74 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala @@ -5,7 +5,6 @@ import ch.epfl.scala.decoder.binary.* import scala.collection.mutable import org.objectweb.asm import java.io.IOException -import ch.epfl.scala.decoder.binary.SignedName import java.net.URLClassLoader import java.nio.file.Path @@ -16,8 +15,12 @@ class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extend loadedClasses.getOrElseUpdate(cls, doLoadClass(cls)) override def loadClass(name: String): JavaReflectClass = - val cls = classLoader.loadClass(name) - loadClass(cls) + name match + case s"$tpe[]" => + val componentType = loadClass(tpe) + JavaReflectClass.array(componentType) + case _ => + JavaReflectClass.primitives.getOrElse(name, loadClass(classLoader.loadClass(name))) private def doLoadClass(cls: Class[?]): JavaReflectClass = val extraInfo = @@ -47,10 +50,15 @@ class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extend exceptions: Array[String] ): asm.MethodVisitor = new asm.MethodVisitor(asm.Opcodes.ASM9): - val lines = mutable.Set.empty[Int] val instructions = mutable.Buffer.empty[Instruction] + val variables = mutable.Buffer.empty[ExtraMethodInfo.Variable] + val labelLines = mutable.Map.empty[asm.Label, Int] + val labels = mutable.Buffer.empty[asm.Label] + override def visitLabel(label: asm.Label): Unit = + labels += label override def visitLineNumber(line: Int, start: asm.Label): Unit = - lines += line + // println("line: " + (line) + " start: " + start + " sourceName: " + sourceName) + labelLines += start -> line override def visitFieldInsn(opcode: Int, owner: String, name: String, descriptor: String): Unit = instructions += Instruction.Field(opcode, owner.replace('/', '.'), name, descriptor) override def visitMethodInsn( @@ -61,10 +69,33 @@ class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extend isInterface: Boolean ): Unit = instructions += Instruction.Method(opcode, owner.replace('/', '.'), name, descriptor, isInterface) + // We should fix the compiler instead + // if descriptor.startsWith("(Lscala/runtime/Lazy") then + // variables += ExtraMethodInfo.Variable(name + "$lzyVal", descriptor.substring(descriptor.indexOf(')') + 1), null) + override def visitLocalVariable( + name: String, + descriptor: String, + signature: String, + start: asm.Label, + end: asm.Label, + index: Int + ): Unit = + variables += ExtraMethodInfo.Variable(name, descriptor, signature, start, end) override def visitEnd(): Unit = - allLines ++= lines - val sourceLines = Option.when(sourceName.nonEmpty)(SourceLines(sourceName, lines.toSeq)) - extraInfos += SignedName(name, descriptor) -> ExtraMethodInfo(sourceLines, instructions.toSeq) + allLines ++= labelLines.values + val sourceLines = Option.when(sourceName.nonEmpty)(SourceLines(sourceName, labelLines.values.toSeq)) + var latestLine: Option[Int] = None + val labelsWithLines: mutable.Map[asm.Label, Int] = mutable.Map.empty + for label <- labels + do + latestLine = labelLines.get(label).orElse(latestLine) + latestLine.foreach(line => labelsWithLines += label -> line) + extraInfos += SignedName(name, descriptor) -> ExtraMethodInfo( + sourceLines, + instructions.toSeq, + variables.toSeq, + labelsWithLines.toMap + ) reader.accept(visitor, asm.Opcodes.ASM9) val sourceLines = Option.when(sourceName.nonEmpty)(SourceLines(sourceName, allLines.toSeq)) ExtraClassInfo(sourceLines, extraInfos.toMap) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala index 85d06ca..aaf86c3 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala @@ -5,6 +5,9 @@ import ch.epfl.scala.decoder.binary import java.lang.reflect.Method import java.lang.reflect.Modifier import ch.epfl.scala.decoder.binary.SignedName +import ch.epfl.scala.decoder.binary.Instruction +import org.objectweb.asm +import ch.epfl.scala.decoder.binary.SourceLines class JavaReflectMethod( method: Method, @@ -12,7 +15,6 @@ class JavaReflectMethod( extraInfos: ExtraMethodInfo, loader: JavaReflectLoader ) extends binary.Method: - override def returnType: Option[binary.Type] = Option(method.getReturnType).map(loader.loadClass) @@ -40,3 +42,19 @@ class JavaReflectMethod( override def sourceLines: Option[binary.SourceLines] = extraInfos.sourceLines override def instructions: Seq[binary.Instruction] = extraInfos.instructions + + override def variables: Seq[binary.Variable] = + for variable <- extraInfos.variables + yield + val typeName = asm.Type.getType(variable.descriptor).getClassName + val sourceLines = + for + sourceName <- sourceName + line <- extraInfos.labels.get(variable.start) + yield SourceLines(sourceName, Seq(line)) + AsmVariable( + variable.name, + loader.loadClass(typeName), + this, + sourceLines + ) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala index f890c06..22cfe9e 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala @@ -20,7 +20,7 @@ class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with ClassTy override def isInterface = ref.isInstanceOf[com.sun.jdi.InterfaceType] - override def sourceLines: Option[SourceLines] = Some(SourceLines(sourceName, allLineLocations.map(_.lineNumber))) + override def sourceLines: Option[SourceLines] = sourceName.map(SourceLines(_, allLineLocations.map(_.lineNumber))) override def method(name: String, sig: String): Option[Method] = visibleMethods.find(_.signedName == SignedName(name, sig)) @@ -40,6 +40,6 @@ class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with ClassTy try ref.allLineLocations.asScala.toSeq catch case e: com.sun.jdi.AbsentInformationException => Seq.empty - private[jdi] def sourceName: String = ref.sourceName + override def sourceName: Option[String] = Option(ref.sourceName) private def visibleMethods: Seq[JdiMethod] = ref.visibleMethods.asScala.map(JdiMethod(_)).toSeq diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala new file mode 100644 index 0000000..d9b730c --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala @@ -0,0 +1,13 @@ +package ch.epfl.scala.decoder.jdi + +import ch.epfl.scala.decoder.binary.* + +class JdiField(field: com.sun.jdi.Field) extends Field: + + override def declaringClass: ClassType = JdiClass(field.declaringType()) + + override def isStatic: Boolean = field.isStatic() + + override def name: String = field.name + override def sourceLines: Option[SourceLines] = None + override def `type`: Type = JdiType(field.`type`) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala index a1c4059..f2e3538 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala @@ -14,6 +14,9 @@ class JdiMethod(method: com.sun.jdi.Method) extends Method: override def allParameters: Seq[Parameter] = method.arguments.asScala.toSeq.map(JdiLocalVariable.apply(_)) + override def variables: Seq[Variable] = + method.variables().asScala.toSeq.map(JdiVariable.apply(_, method)) + override def returnType: Option[Type] = try Some(JdiType(method.returnType)) catch case e: com.sun.jdi.ClassNotLoadedException => None @@ -29,7 +32,7 @@ class JdiMethod(method: com.sun.jdi.Method) extends Method: override def isConstructor: Boolean = method.isConstructor override def sourceLines: Option[SourceLines] = - Some(SourceLines(declaringClass.sourceName, allLineLocations.map(_.lineNumber))) + declaringClass.sourceName.map(SourceLines(_, allLineLocations.map(_.lineNumber))) override def signedName: SignedName = SignedName(name, signature) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala new file mode 100644 index 0000000..f986e1a --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala @@ -0,0 +1,12 @@ +package ch.epfl.scala.decoder.jdi + +import ch.epfl.scala.decoder.binary.* + +class JdiVariable(variable: com.sun.jdi.LocalVariable, method: com.sun.jdi.Method) extends Variable: + + override def declaringMethod: Method = + JdiMethod(method) + + override def name: String = variable.name + override def sourceLines: Option[SourceLines] = None + override def `type`: Type = JdiType(variable.`type`) diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala index 2d17a0f..0771c4d 100644 --- a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala @@ -14,7 +14,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( expectedClasses = ExpectedCount(4426), expectedMethods = ExpectedCount(68421, ambiguous = 25, notFound = 33), - expectedFields = ExpectedCount(12548, ambiguous = 27, notFound = 1) + expectedFields = ExpectedCount(12550, ambiguous = 23, notFound = 3), + expectedVariables = ExpectedCount(129844, ambiguous = 4927, notFound = 2475) ) test("scala3-compiler:3.0.2"): @@ -22,7 +23,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( expectedClasses = ExpectedCount(3859, notFound = 3), expectedMethods = ExpectedCount(60762, ambiguous = 24, notFound = 163), - expectedFields = ExpectedCount(10670, ambiguous = 23, notFound = 6) + expectedFields = ExpectedCount(10674, ambiguous = 19, notFound = 6), + expectedVariables = ExpectedCount(112306, ambiguous = 4443, notFound = 2187) ) test("io.github.vigoo:zio-aws-ec2_3:4.0.5 - slow".ignore): @@ -37,7 +39,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( expectedClasses = ExpectedCount(10), expectedMethods = ExpectedCount(218), - expectedFields = ExpectedCount(45) + expectedFields = ExpectedCount(45), + expectedVariables = ExpectedCount(194, ambiguous = 1, notFound = 45) ) test("net.zygfryd:jackshaft_3:0.2.2".ignore): @@ -49,10 +52,11 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( expectedClasses = ExpectedCount(245), expectedMethods = ExpectedCount(2755, notFound = 92), - expectedFields = ExpectedCount(298) + expectedFields = ExpectedCount(298), + expectedVariables = ExpectedCount(4541, ambiguous = 58, notFound = 38) ) - test("org.clulab:processors-main_3:8.5.3"): + test("org.clulab:processors-main_3:8.5.3".ignore): assume(!isCI) val repository = MavenRepository("http://artifactory.cs.arizona.edu:8081/artifactory/sbt-release") val decoder = initDecoder("org.clulab", "processors-main_3", "8.5.3", FetchOptions(repositories = Seq(repository))) @@ -66,7 +70,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(27), ExpectedCount(174, notFound = 2), - expectedFields = ExpectedCount(20, ambiguous = 4) + expectedFields = ExpectedCount(20, ambiguous = 4), + expectedVariables = ExpectedCount(253, ambiguous = 3, notFound = 6) ) test("com.zengularity:benji-google_3:2.2.1".ignore): @@ -96,7 +101,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(149, notFound = 9), ExpectedCount(3546, notFound = 59), - expectedFields = ExpectedCount(144, notFound = 2) + expectedFields = ExpectedCount(144, notFound = 2), + expectedVariables = ExpectedCount(14750, ambiguous = 275, notFound = 39) ) test("com.evolution:scache_3:5.1.2"): @@ -107,7 +113,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(105), ExpectedCount(1509), - expectedFields = ExpectedCount(161) + expectedFields = ExpectedCount(161), + expectedVariables = ExpectedCount(3150, ambiguous = 51, notFound = 9) ) test("com.github.j5ik2o:docker-controller-scala-dynamodb-local_:1.15.34"): @@ -118,7 +125,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(2), ExpectedCount(37), - expectedFields = ExpectedCount(5) + expectedFields = ExpectedCount(5), + expectedVariables = ExpectedCount(30) ) test("eu.ostrzyciel.jelly:jelly-grpc_3:0.5.3"): @@ -127,7 +135,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(24), ExpectedCount(353), - expectedFields = ExpectedCount(61) + expectedFields = ExpectedCount(61), + expectedVariables = ExpectedCount(443, ambiguous = 3, notFound = 2) ) test("com.devsisters:zio-agones_3:0.1.0"): @@ -137,7 +146,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(83, notFound = 26), ExpectedCount(2804, ambiguous = 2, notFound = 5), - expectedFields = ExpectedCount(258) + expectedFields = ExpectedCount(258), + expectedVariables = ExpectedCount(3706, ambiguous = 17, notFound = 1, throwables = 48) ) test("org.log4s:log4s_3:1.10.0".ignore): @@ -164,7 +174,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(19), ExpectedCount(158), - expectedFields = ExpectedCount(32, ambiguous = 4, notFound = 2) + expectedFields = ExpectedCount(32, ambiguous = 4, notFound = 2), + expectedVariables = ExpectedCount(204, notFound = 2) ) test("io.github.valdemargr:gql-core_3:0.3.3"): @@ -172,5 +183,6 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( ExpectedCount(531), ExpectedCount(7267, ambiguous = 4, notFound = 1), - expectedFields = ExpectedCount(851, notFound = 2) + expectedFields = ExpectedCount(851, notFound = 2), + expectedVariables = ExpectedCount(14771, ambiguous = 313, notFound = 26) ) diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStatsFull.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStatsFull.scala index 72f0a36..45b377f 100644 --- a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStatsFull.scala +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStatsFull.scala @@ -31,7 +31,7 @@ class BinaryDecoderStatsFull extends BinaryDecoderSuite: val parts = line.split(',').map(_.drop(1).dropRight(1)) val (org, artifact, version) = (parts(0), parts(1), parts(2)) try - val (classCounter, methodCounter, _) = tryDecodeAll(org, artifact, version) + val (classCounter, methodCounter, _, _) = tryDecodeAll(org, artifact, version) classCounts += classCounter.count methodCounts += methodCounter.count catch case e => println(s"cannot decode $line") @@ -45,7 +45,9 @@ class BinaryDecoderStatsFull extends BinaryDecoderSuite: .sortBy(count => -count.successPercent) .foreach(c => println(s"${c.name} ${c.successPercent}%")) - def tryDecodeAll(org: String, artifact: String, version: String)(using ThrowOrWarn): (Counter, Counter, Counter) = + def tryDecodeAll(org: String, artifact: String, version: String)(using + ThrowOrWarn + ): (Counter, Counter, Counter, Counter) = val repositories = if org == "org.clulab" then Seq(MavenRepository("http://artifactory.cs.arizona.edu:8081/artifactory/sbt-release")) else if org == "com.zengularity" then @@ -53,7 +55,7 @@ class BinaryDecoderStatsFull extends BinaryDecoderSuite: else if org == "com.evolution" then Seq(MavenRepository("https://evolution.jfrog.io/artifactory/public")) else if org == "com.github.j5ik2o" then Seq(MavenRepository("https://maven.seasar.org/maven2/")) else Seq.empty - def tryWith(keepOptional: Boolean, keepProvided: Boolean): Option[(Counter, Counter, Counter)] = + def tryWith(keepOptional: Boolean, keepProvided: Boolean): Option[(Counter, Counter, Counter, Counter)] = try val fetchOptions = FetchOptions(keepOptional, keepProvided, repositories) val decoder = initDecoder(org, artifact, version, fetchOptions) diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala index e0d327a..05c465f 100644 --- a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala @@ -12,6 +12,459 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco def isScala33 = scalaVersion.isScala33 def isScala34 = scalaVersion.isScala34 + test("scala3-compiler:3.3.1"): + val decoder = initDecoder("org.scala-lang", "scala3-compiler_3", "3.3.1") + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.QuoteMatcher$", + "scala.Option treeMatch(dotty.tools.dotc.ast.Trees$Tree scrutineeTree, dotty.tools.dotc.ast.Trees$Tree patternTree, dotty.tools.dotc.core.Contexts$Context x$3)", + "scala.util.boundary$Break ex", + "ex: Break[T]", + 128 + ) + + test("tailLocal variables") { + val source = + """|package example + | + |class A { + | @annotation.tailrec + | private def factAcc(x: Int, acc: Int): Int = + | if x <= 1 then List(1, 2).map(_ * acc).sum + | else factAcc(x - 1, x * acc) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)") + // decoder.assertDecodeVariable("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)", "int acc$tailLocal1$1", "acc.: Int", 6) + } + + test("SAMOrPartialFunctionImpl") { + val source = + """|package example + | + |class A: + | def foo(x: Int) = + | val xs = List(x, x + 1, x + 2) + | xs.collect { case z if z % 2 == 0 => z } + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A$$anon$1", "boolean isDefinedAt(int x)") + decoder.assertDecodeVariable( + "example.A$$anon$1", + "boolean isDefinedAt(int x)", + "int x", + "x: A", + 6 + ) + decoder.assertDecodeVariable( + "example.A$$anon$1", + "boolean isDefinedAt(int x)", + "int z", + "z: Int", + 6 + ) + + decoder.showVariables("example.A$$anon$1", "java.lang.Object applyOrElse(int x, scala.Function1 default)") + decoder.assertDecodeVariable( + "example.A$$anon$1", + "java.lang.Object applyOrElse(int x, scala.Function1 default)", + "scala.Function1 default", + "default: A1 => B1", + 6 + ) + } + + test("inlined this") { + val source = + """|package example + | + |class A(x: Int): + | inline def foo: Int = x + x + | + |class B: + | def bar(a: A) = a.foo + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.B", "int bar(example.A a)") + decoder.assertDecodeVariable( + "example.B", + "int bar(example.A a)", + "example.A A_this", + "this: A.this.type", + 7, + generated = true + ) + } + + test("inlined param") { + val source = + """|package example + | + |class A { + | inline def foo(x: Int): Int = x + x + | + | def bar(y: Int) = foo(y + 2) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int bar(int y)") + decoder.assertDecodeVariable("example.A", "int bar(int y)", "int x$proxy1", "x: Int", 4) + } + + test("bridge parameter") { + val source = + """|package example + |class A + |class B extends A + | + |class C: + | def foo(x: Int): A = new A + | + |class D extends C: + | override def foo(y: Int): B = new B + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.D", "example.A foo(int x)") + decoder.assertIgnoredVariable("example.D", "example.A foo(int x)", "int x", "Bridge") + } + + test("lazy val capture") { + val source = + """|package example + | + |class A { + | def foo = + | val y = 4 + | lazy val z = y + 1 + | def bar = z + | z + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int bar$1(scala.runtime.LazyInt z$lzy1$3, int y$3)") + decoder.assertDecodeVariable( + "example.A", + "int bar$1(scala.runtime.LazyInt z$lzy1$3, int y$3)", + "scala.runtime.LazyInt z$lzy1$3", + "z.: Int", + 7, + generated = true + ) + + decoder.showVariables("example.A", "int z$lzyINIT1$1(scala.runtime.LazyInt z$lzy1$1, int y$1)") + decoder.assertDecodeVariable( + "example.A", + "int z$lzyINIT1$1(scala.runtime.LazyInt z$lzy1$1, int y$1)", + "scala.runtime.LazyInt z$lzy1$1", + "z.: Int", + 7, + generated = true + ) + + decoder.showVariables("example.A", "int z$1(scala.runtime.LazyInt z$lzy1$2, int y$2)") + decoder.assertDecodeVariable( + "example.A", + "int z$1(scala.runtime.LazyInt z$lzy1$2, int y$2)", + "scala.runtime.LazyInt z$lzy1$2", + "z.: Int", + 7, + generated = true + ) + } + + test("by-name arg capture") { + val source = + """|package example + | + |class A { + | def foo(x: => Int) = ??? + | + | def bar(x: Int) = + | foo(x) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int bar$$anonfun$1(int x$1)") + decoder.assertDecodeVariable( + "example.A", + "int bar$$anonfun$1(int x$1)", + "int x$1", + "x.: Int", + 7, + generated = true + ) + } + + test("binds") { + val source = + """|package example + | + |class B + |case class C(x: Int, y: String) extends B + |case class D(z: String) extends B + |case class E(v: Int) extends B + |case class F(w: Int) extends B + | + |class A: + | private def bar(a: B) = + | a match + | case F(w) => w + | case C(x, y) => + | x + | case D(z) => 0 + | case E(v) => 1 + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + // decoder.showVariables("example.A", "int bar(example.B a)") + // decoder.assertDecodeAll( + // ExpectedCount(2), + // ExpectedCount(37), + // expectedFields = ExpectedCount(5) + // ) + decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int w", "w: Int", 12) + decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int x", "x: Int", 14) + + } + + test("mixin and trait static forwarders") { + val source = + """|package example + | + |trait A { + | def foo(x: Int): Int = x + |} + | + |class B extends A + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.B", "int foo(int x)") + decoder.showVariables("example.A", "int foo$(example.A $this, int x)") + decoder.assertDecodeVariable("example.B", "int foo(int x)", "int x", "x: Int", 7) + decoder.assertDecodeVariable( + "example.A", + "int foo$(example.A $this, int x)", + "example.A $this", + "this: A.this.type", + 4, + generated = true + ) + } + + test("this AnyVal") { + val source = + """|package example + | + |class A(x: Int) extends AnyVal { + | def foo: Int = + | x + | + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A$", "int foo$extension(int $this)") + decoder.assertDecodeVariable( + "example.A$", + "int foo$extension(int $this)", + "int $this", + "x: Int", + 5, + generated = true + ) + } + + test("this variable") { + val source = + """|package example + | + |class A: + | def foo: Int = + | 4 + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int foo()") + decoder.assertDecodeVariable("example.A", "int foo()", "example.A this", "this: A.this.type", 5, generated = true) + } + + test("binds tuple and pattern matching") { + val source = + """|package example + | + |class A { + | def foo: Int = + | val x = (1, 2) + | val (c, d) = (3, 4) + | x match + | case (a, b) => a + b + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int foo()") + decoder.assertDecodeVariable("example.A", "int foo()", "scala.Tuple2 x", "x: (Int, Int)", 5) + decoder.assertDecodeVariable("example.A", "int foo()", "int c", "c: Int", 6) + decoder.assertDecodeVariable("example.A", "int foo()", "int d", "d: Int", 6) + decoder.assertDecodeVariable("example.A", "int foo()", "int a", "a: Int", 8) + decoder.assertDecodeVariable("example.A", "int foo()", "int b", "b: Int", 8) + } + + test("ambiguous impossible") { + val source = + """|package example + | + |class A: + | def foo(a: Boolean) = + | if (a) {val x = 1} else {val x = 2} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "void foo(boolean a)") + decoder.assertAmbiguousVariable("example.A", "void foo(boolean a)", "int x", 5) + } + + test("ambiguous variables 2") { + val source = + """|package example + | + |class A: + | def foo() = + | var i = 0 + | while i < 10 do + | val x = i + | i += 1 + | val x = 17 + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "void foo()") + decoder.assertDecodeVariable("example.A", "void foo()", "int x", "x: Int", line = 7) + decoder.assertDecodeVariable("example.A", "void foo()", "int x", "x: Int", line = 9) + } + + test("ambiguous variables") { + val source = + """|package example + | + |class A : + | def foo(a: Boolean) = + | if a then + | val x = 1 + | x + | else + | val x = "2" + | x + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "java.lang.Object foo(boolean a)") + decoder.assertDecodeVariable("example.A", "java.lang.Object foo(boolean a)", "int x", "x: Int", line = 7) + decoder.assertDecodeVariable( + "example.A", + "java.lang.Object foo(boolean a)", + "java.lang.String x", + "x: String", + line = 9 + ) + } + + test("local object") { + val source = + """|package example + | + |class A { + | def foo() = + | object B + | B + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "java.lang.Object foo()") + decoder.assertNoSuchElementVariable("example.A", "java.lang.Object foo()", "example.A$B$2$ B$1") + } + + test("local lazy val") { + val source = + """|package example + | + |class A: + | def foo() = + | lazy val x: Int = 1 + | x + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int foo()") + decoder.assertNoSuchElementVariable("example.A", "int foo()", "int x$1$lzyVal") + } + + test("array") { + val source = + """|package example + | + |class A { + | def foo() = + | val x = Array(1, 2, 3) + | x + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int[] foo()") + decoder.assertDecodeVariable("example.A", "int[] foo()", "int[] x", "x: Array[Int]", 6) + } + + test("captured param in a local def") { + val source = + """|package example + | + |class A { + | def foo(x: Int) = { + | def bar() = x + | } + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int bar$1(int x$1)") + decoder.assertDecodeVariable( + "example.A", + "int bar$1(int x$1)", + "int x$1", + "x.: Int", + 5, + generated = true + ) + } + + test("method parameter") { + val source = + """|package example + | + |class A: + | def foo(y: String) = + | println(y) + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "void foo(java.lang.String y)") + decoder.assertDecodeVariable( + "example.A", + "void foo(java.lang.String y)", + "java.lang.String y", + "y: String", + 5 + ) + } + + test("local variable") { + val source = + """|package example + | + |class A: + | def foo = + | val x: Int = 1 + | x + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int foo()") + decoder.assertDecodeVariable("example.A", "int foo()", "int x", "x: Int", 6) + } + test("capture value class") { val source = """|package example @@ -2248,6 +2701,12 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco "TreeUnpickler.readTpt.()(using Contexts.Context): tpd.Tree", generated = true ) + decoder.assertNotFoundVariable( + "scala.quoted.runtime.impl.QuotesImpl$reflect$defn$", + "dotty.tools.dotc.core.Symbols$Symbol TupleClass(int arity)", + "dotty.tools.dotc.core.Types$TypeRef x$proxy1", + 2816 + ) test("tasty-query#412"): val decoder = initDecoder("dev.zio", "zio-interop-cats_3", "23.1.0.0")(using ThrowOrWarn.ignore) diff --git a/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala b/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala index 7bef8a7..4639075 100644 --- a/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala +++ b/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala @@ -31,6 +31,12 @@ trait BinaryDecoderSuite extends CommonFunSuite: TestingDecoder(library, libraries) extension (decoder: TestingDecoder) + def showVariables(className: String, method: String): Unit = + val variables = loadBinaryMethod(className, method).variables + println( + s"Available binary variables in $method are:\n" + variables.map(f => s" " + formatVariable(f)).mkString("\n") + ) + def showFields(className: String): Unit = val fields = decoder.classLoader.loadClass(className).declaredFields println(s"Available binary fields in $className are:\n" + fields.map(f => s" " + formatField(f)).mkString("\n")) @@ -56,6 +62,42 @@ trait BinaryDecoderSuite extends CommonFunSuite: assertEquals(formatter.format(decodedField), expected) assertEquals(decodedField.isGenerated, generated) + def assertDecodeVariable( + className: String, + method: String, + variable: String, + expected: String, + line: Int, + generated: Boolean = false + )(using + munit.Location + ): Unit = + val binaryVariable = loadBinaryVariable(className, method, variable) + val decodedVariable = decoder.decode(binaryVariable, line) + assertEquals(formatter.format(decodedVariable), expected) + assertEquals(decodedVariable.isGenerated, generated) + + def assertAmbiguousVariable(className: String, method: String, variable: String, line: Int)(using + munit.Location + ): Unit = + val binaryVariable = loadBinaryVariable(className, method, variable) + intercept[AmbiguousException](decoder.decode(binaryVariable, line)) + + def assertNotFoundVariable(className: String, method: String, variable: String, line: Int)(using + munit.Location + ): Unit = + val binaryVariable = loadBinaryVariable(className, method, variable) + intercept[NotFoundException](decoder.decode(binaryVariable, line)) + + def assertNoSuchElementVariable(className: String, method: String, variable: String)(using munit.Location): Unit = + intercept[NoSuchElementException](loadBinaryVariable(className, method, variable)) + + def assertIgnoredVariable(className: String, method: String, variable: String, reason: String)(using + munit.Location + ): Unit = + val binaryVariable = loadBinaryVariable(className, method, variable) + intercept[IgnoredException](decoder.decode(binaryVariable, 0)) + def assertAmbiguousField(className: String, field: String)(using munit.Location): Unit = val binaryField: binary.Field = loadBinaryField(className, field) intercept[AmbiguousException](decoder.decode(binaryField)) @@ -88,40 +130,42 @@ trait BinaryDecoderSuite extends CommonFunSuite: expectedClasses: ExpectedCount = ExpectedCount(0), expectedMethods: ExpectedCount = ExpectedCount(0), expectedFields: ExpectedCount = ExpectedCount(0), + expectedVariables: ExpectedCount = ExpectedCount(0), printProgress: Boolean = false )(using munit.Location): Unit = - val (classCounter, methodCounter, fieldCounter) = decodeAll(printProgress) + val (classCounter, methodCounter, fieldCounter, variableCounter) = decodeAll(printProgress) if classCounter.throwables.nonEmpty then classCounter.printThrowables() classCounter.printThrowable(0) else if methodCounter.throwables.nonEmpty then methodCounter.printThrowables() methodCounter.printThrowable(0) - fieldCounter.printNotFound() + else if variableCounter.throwables.nonEmpty then variableCounter.printThrowable(0) + // variableCounter.printNotFound(40) classCounter.check(expectedClasses) methodCounter.check(expectedMethods) fieldCounter.check(expectedFields) + variableCounter.check(expectedVariables) - def decodeAll(printProgress: Boolean = false): (Counter, Counter, Counter) = + def decodeAll(printProgress: Boolean = false): (Counter, Counter, Counter, Counter) = val classCounter = Counter(decoder.name + " classes") val methodCounter = Counter(decoder.name + " methods") val fieldCounter = Counter(decoder.name + " fields") + val variableCounter = Counter(decoder.name + " variables") for binaryClass <- decoder.allClasses _ = if printProgress then println(s"\"${binaryClass.name}\"") decodedClass <- decoder.tryDecode(binaryClass, classCounter) - binaryMethodOrField <- binaryClass.declaredMethods ++ binaryClass.declaredFields - do - if printProgress then println(formatDebug(binaryMethodOrField)) - binaryMethodOrField match - case m: binary.Method => - decoder.tryDecode(decodedClass, m, methodCounter) - case f: binary.Field => - decoder.tryDecode(decodedClass, f, fieldCounter) + _ = binaryClass.declaredFields.foreach(f => decoder.tryDecode(decodedClass, f, fieldCounter)) + binaryMethod <- binaryClass.declaredMethods + decodedMethod <- decoder.tryDecode(decodedClass, binaryMethod, methodCounter) + binaryVariable <- binaryMethod.variables + do decoder.tryDecode(decodedMethod, binaryVariable, variableCounter) classCounter.printReport() methodCounter.printReport() fieldCounter.printReport() - (classCounter, methodCounter, fieldCounter) + variableCounter.printReport() + (classCounter, methodCounter, fieldCounter, variableCounter) private def loadBinaryMethod(declaringType: String, method: String)(using munit.Location @@ -143,6 +187,19 @@ trait BinaryDecoderSuite extends CommonFunSuite: |""".stripMargin + binaryFields.map(f => s" " + formatField(f)).mkString("\n") binaryFields.find(f => formatField(f) == field).getOrElse(throw new NoSuchElementException(notFoundMessage)) + private def loadBinaryVariable(declaringType: String, method: String, variableName: String)(using + munit.Location + ): binary.Variable = + val binaryMethod = loadBinaryMethod(declaringType, method) + val binaryVariables = binaryMethod.variables + def notFoundMessage: String = + s"""|$variableName + | Available binary variables in $method are: + |""".stripMargin + binaryVariables.map(v => s" " + formatVariable(v)).mkString("\n") + binaryVariables + .find(v => formatVariable(v) == variableName) + .getOrElse(throw new NoSuchElementException(notFoundMessage)) + private def tryDecode(cls: binary.ClassType, counter: Counter): Option[DecodedClass] = try val sym = decoder.decode(cls) @@ -159,15 +216,24 @@ trait BinaryDecoderSuite extends CommonFunSuite: counter.throwables += (cls -> e) None - private def tryDecode(cls: DecodedClass, mthd: binary.Method, counter: Counter): Unit = + private def tryDecode(cls: DecodedClass, mthd: binary.Method, counter: Counter): Option[DecodedMethod] = try val decoded = decoder.decode(cls, mthd) counter.success += (mthd -> decoded) + Some(decoded) catch - case notFound: NotFoundException => counter.notFound += (mthd -> notFound) - case ambiguous: AmbiguousException => counter.ambiguous += ambiguous - case ignored: IgnoredException => counter.ignored += ignored - case e => counter.throwables += (mthd -> e) + case notFound: NotFoundException => + counter.notFound += (mthd -> notFound) + None + case ambiguous: AmbiguousException => + counter.ambiguous += ambiguous + None + case ignored: IgnoredException => + counter.ignored += ignored + None + case e => + counter.throwables += (mthd -> e) + None private def tryDecode(cls: DecodedClass, field: binary.Field, counter: Counter): Unit = try @@ -178,12 +244,27 @@ trait BinaryDecoderSuite extends CommonFunSuite: case ambiguous: AmbiguousException => counter.ambiguous += ambiguous case ignored: IgnoredException => counter.ignored += ignored case e => counter.throwables += (field -> e) + + private def tryDecode(mtd: DecodedMethod, variable: binary.Variable, counter: Counter): Unit = + try + val decoded = decoder.decode(mtd, variable, variable.sourceLines.get.lines.head) + counter.success += (variable -> decoded) + catch + case notFound: NotFoundException => counter.notFound += (variable -> notFound) + case ambiguous: AmbiguousException => counter.ambiguous += ambiguous + case ignored: IgnoredException => counter.ignored += ignored + case e => counter.throwables += (variable -> e) end extension private def formatDebug(m: binary.Symbol): String = m match case f: binary.Field => s"\"${f.declaringClass}\", \"${formatField(f)}\"" case m: binary.Method => s"\"${m.declaringClass.name}\", \"${formatMethod(m)}\"" + case v: binary.Variable => + s"\"${v.declaringMethod.declaringClass.name}\",\n \"${formatMethod( + v.declaringMethod + )}\",\n \"${formatVariable(v)}\",\n \"\",\n ${v.sourceLines.get.lines.head}, " + // s"\"${v.showSpan}" case cls => s"\"${cls.name}\"" private def formatMethod(m: binary.Method): String = @@ -194,6 +275,9 @@ trait BinaryDecoderSuite extends CommonFunSuite: private def formatField(f: binary.Field): String = s"${f.`type`.name} ${f.name}" + private def formatVariable(v: binary.Variable): String = + s"${v.`type`.name} ${v.name}" + case class ExpectedCount(success: Int, ambiguous: Int = 0, notFound: Int = 0, throwables: Int = 0) case class Count(name: String, success: Int = 0, ambiguous: Int = 0, notFound: Int = 0, throwables: Int = 0): @@ -253,8 +337,13 @@ trait BinaryDecoderSuite extends CommonFunSuite: println(s"mean: ${formatted.map((j, s) => s.size - j.size).sum / formatted.size}") end printComparisionWithJavaFormatting + def printSuccess() = + success.foreach { (s, d) => + println(s"${formatDebug(s)}: $d") + } + def printNotFound() = - notFound.foreach { case (s1, NotFoundException(s2)) => + notFound.foreach { case (s1, NotFoundException(s2, _)) => if s1 != s2 then println(s"${formatDebug(s1)} not found because of ${formatDebug(s2)}") else println(s"${formatDebug(s1)} not found") } @@ -264,6 +353,19 @@ trait BinaryDecoderSuite extends CommonFunSuite: println(s"${formatDebug(s)} is ambiguous:" + candidates.map(s"\n - " + _).mkString) } + // print the first n ambiguous symbols + def printAmbiguous(n: Int) = + ambiguous.take(n).foreach { case AmbiguousException(s, candidates) => + println(s"${formatDebug(s)} is ambiguous:" + candidates.map(s"\n - " + _).mkString) + } + + def printNotFound(n: Int) = + notFound.take(n).foreach { case (s1, NotFoundException(s2, owner)) => + if s1 != s2 then println(s"${formatDebug(s1)} not found because of ${formatDebug(s2)}") + else println(s"- ${formatDebug(s1)} not found " + owner.map(o => s"in ${o.getClass.getSimpleName()}")) + println("") + } + def printThrowable(i: Int) = if throwables.size > i then val (sym, t) = throwables(i) @@ -274,6 +376,11 @@ trait BinaryDecoderSuite extends CommonFunSuite: println(s"${formatDebug(sym)} $t") } + def printNThrowables(n: Int) = + throwables.take(n).foreach { (sym, t) => + println(s"${formatDebug(sym)} $t") + } + def check(expected: ExpectedCount)(using munit.Location): Unit = assertEquals(success.size, expected.success) assertEquals(ambiguous.size, expected.ambiguous) diff --git a/src/test/scala/ch/epfl/scala/decoder/testutils/ClasspathEntry.scala b/src/test/scala/ch/epfl/scala/decoder/testutils/ClasspathEntry.scala index 66fcd34..9cd5aaa 100644 --- a/src/test/scala/ch/epfl/scala/decoder/testutils/ClasspathEntry.scala +++ b/src/test/scala/ch/epfl/scala/decoder/testutils/ClasspathEntry.scala @@ -5,3 +5,4 @@ import java.nio.file.Path case class ClasspathEntry(name: String, absolutePath: Path): def toURL: URL = absolutePath.toUri.toURL + def isJar: Boolean = absolutePath.toString.endsWith(".jar") diff --git a/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala b/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala index f3a5098..98f6b4a 100644 --- a/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala +++ b/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala @@ -9,6 +9,7 @@ import java.nio.file.Files import java.nio.file.Path import scala.jdk.CollectionConverters.* import scala.util.Properties +import java.nio.file.FileSystem object TestingDecoder: def javaRuntime = JavaRuntime(Properties.jdkHome).get @@ -43,17 +44,17 @@ class TestingDecoder(mainEntry: ClasspathEntry, val classLoader: BinaryClassLoad decode(binaryClass) def name: String = mainEntry.name def allClasses: Seq[binary.ClassType] = - val classNames = IO - .withinJarFile(mainEntry.absolutePath) { fs => - val classMatcher = fs.getPathMatcher("glob:**.class") - Files - .walk(fs.getPath("/"): Path) - .filter(classMatcher.matches) - .iterator - .asScala - .map(_.toString.stripPrefix("/").stripSuffix(".class").replace('/', '.')) - .filterNot(_.endsWith("module-info")) - .toSeq - } - .get + def listClassNames(root: Path): Seq[String] = + val classMatcher = root.getFileSystem.getPathMatcher("glob:**.class") + Files + .walk(root) + .filter(classMatcher.matches) + .iterator + .asScala + .map(path => root.relativize(path).toString.stripPrefix("/").stripSuffix(".class").replace('/', '.')) + .filterNot(_.endsWith("module-info")) + .toSeq + val classNames = + if mainEntry.isJar then IO.withinJarFile(mainEntry.absolutePath)(fs => listClassNames(fs.getPath("/"))).get + else listClassNames(mainEntry.absolutePath) classNames.map(classLoader.loadClass)