diff --git a/.scalafmt.conf b/.scalafmt.conf index dc8d4ba..9acc72f 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = "3.8.1" +version = "3.8.4" project.git = true align.preset = none align.stripMargin = true diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 59fa764..f248ab8 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -1,16 +1,16 @@ import sbt._ object Dependencies { - val scala3Next = "3.4.2" - val asmVersion = "9.7" - val coursierVersion = "2.1.10" + val scala3Next = "3.6.2" + val asmVersion = "9.7.1" + val coursierVersion = "2.1.24" - val tastyQuery = "ch.epfl.scala" %% "tasty-query" % "1.3.0" + val tastyQuery = "ch.epfl.scala" %% "tasty-query" % "1.4.0" val asm = "org.ow2.asm" % "asm" % asmVersion val asmUtil = "org.ow2.asm" % "asm-util" % asmVersion // test dependencies - val munit = "org.scalameta" %% "munit" % "1.0.0" + val munit = "org.scalameta" %% "munit" % "1.0.4" val coursier = ("io.get-coursier" %% "coursier" % coursierVersion).cross(CrossVersion.for3Use2_13) val coursierJvm = ("io.get-coursier" %% "coursier-jvm" % coursierVersion).cross(CrossVersion.for3Use2_13) } diff --git a/project/build.properties b/project/build.properties index b485f62..0a832a2 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.0 \ No newline at end of file +sbt.version=1.10.7 \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt index 3eb12e7..935201d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,2 +1,2 @@ -addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12") -addSbtPlugin("org.scala-debugger" % "sbt-jdi-tools" % "1.1.1") +addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.9.2") +addSbtPlugin("com.github.sbt" % "sbt-jdi-tools" % "1.2.0") diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryClassDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryClassDecoder.scala new file mode 100644 index 0000000..0a3240c --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryClassDecoder.scala @@ -0,0 +1,172 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.internal.* +import tastyquery.Contexts.* +import tastyquery.Names.* +import tastyquery.Symbols.* + +import scala.util.matching.Regex + +trait BinaryClassDecoder(using Context, ThrowOrWarn): + self: BinaryDecoder => + + protected val scoper = Scoper() + + def decode(cls: binary.BinaryClass): DecodedClass = + val javaParts = cls.name.split('.') + val packageNames = javaParts.dropRight(1).toList.map(SimpleName.apply) + val packageSym = + if packageNames.nonEmpty + then ctx.findSymbolFromRoot(packageNames).asInstanceOf[PackageSymbol] + else defn.EmptyPackage + val decodedClassName = NameTransformer.decode(javaParts.last) + val allSymbols = decodedClassName match + case Patterns.AnonClass(declaringClassName, remaining) => + val WithLocalPart = "(.+)\\$(.+)\\$\\d+".r + val topLevelClassName = declaringClassName match + case WithLocalPart(topLevelClassName, _) => topLevelClassName.stripSuffix("$") + case topLevelClassName => topLevelClassName + reduceAmbiguityOnClasses(decodeLocalClasses(cls, packageSym, topLevelClassName, "$anon", remaining)) + case Patterns.LocalClass(declaringClassName, localClassName, remaining) => + decodeLocalClasses(cls, packageSym, declaringClassName, localClassName, remaining) + case _ => decodeClassFromPackage(packageSym, decodedClassName) + + val candidates = + if cls.isObject then allSymbols.filter(_.isModuleClass) + else if cls.sourceLines.forall(_.isEmpty) && allSymbols.forall(_.isModuleClass) then + allSymbols.collect { case cls: DecodedClass.ClassDef => DecodedClass.SyntheticCompanionClass(cls.symbol) } + else allSymbols.filter(!_.isModuleClass) + candidates.singleOrThrow(cls) + end decode + + private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] = + if syms.size > 1 then + val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _))) + if reduced.size != 0 then reduced else syms + else syms + + private def decodeLocalClasses( + javaClass: binary.BinaryClass, + packageSym: PackageSymbol, + declaringClassName: String, + localClassName: String, + remaining: Option[String] + ): Seq[DecodedClass] = + val classOwners = decodeClassFromPackage(packageSym, declaringClassName).map(_.symbol) + remaining match + case None => + val parents = (javaClass.superclass.toSet ++ javaClass.interfaces) + .map(decode) + .collect { case cls: DecodedClass.ClassDef => cls.symbol } + classOwners + .flatMap(cls => collectLocalClasses(cls, localClassName, javaClass.sourceLines)) + .filter(matchParents(_, parents, javaClass.isInterface)) + case Some(remaining) => + val localClasses = classOwners + .flatMap(cls => collectLocalClasses(cls, localClassName, None)) + .flatMap(_.classSymbol) + localClasses.flatMap(s => decodeClassRecursively(s, remaining)) + + private def decodeClassFromPackage(owner: PackageSymbol, decodedName: String): Seq[DecodedClass.ClassDef] = + val packageObject = "([^\\$]+\\$package)(\\$.*)?".r + val specializedClass = "([^\\$]+\\$mc.+\\$sp)(\\$.*)?".r + val standardClass = "([^\\$]+)(\\$.*)?".r + val topLevelName = decodedName match + case packageObject(name, _) => name + case specializedClass(name, _) => name + case standardClass(name, _) => name + val remaining = decodedName.stripPrefix(topLevelName).stripPrefix("$") + val typeNames = Seq(typeName(topLevelName), moduleClassName(topLevelName)) + typeNames + .flatMap(owner.getDecl) + .collect { case sym: ClassSymbol => sym } + .flatMap { sym => + if remaining.isEmpty then Seq(DecodedClass.ClassDef(sym)) + else decodeClassRecursively(sym, remaining) + } + + private def enclose(enclosing: DecodedClass, enclosed: DecodedClass): Boolean = + (enclosing, enclosed) match + case (enclosing: DecodedClass.InlinedClass, enclosed: DecodedClass.InlinedClass) => + enclosing.callPos.enclose(enclosed.callPos) || ( + !enclosed.callPos.enclose(enclosing.callPos) && + enclose(enclosing.underlying, enclosed.underlying) + ) + case (enclosing: DecodedClass.InlinedClass, enclosed) => + enclosing.callPos.enclose(enclosed.pos) + case (enclosing, enclosed: DecodedClass.InlinedClass) => + enclosing.pos.enclose(enclosed.callPos) + case (enclosing, enclosed) => + enclosing.pos.enclose(enclosed.pos) + + private def collectLocalClasses( + classSymbol: ClassSymbol, + name: String, + sourceLines: Option[binary.SourceLines] + ): Seq[DecodedClass] = + val localClasses = collectLiftedTrees(classSymbol, sourceLines) { + case cls: LocalClass if cls.symbol.sourceName == name => cls + } + .map(cls => wrapIfInline(cls, DecodedClass.ClassDef(cls.symbol))) + val samAndPartialFunctions = collectLiftedTrees(classSymbol, sourceLines) { case lambda: LambdaTree => lambda } + .map { lambda => + val (term, samClass) = lambda.symbol + wrapIfInline(lambda, DecodedClass.SAMOrPartialFunction(term, samClass, lambda.tpe.asInstanceOf)) + } + localClasses ++ samAndPartialFunctions + + private def matchParents( + decodedClass: DecodedClass, + expectedParents: Set[ClassSymbol], + isInterface: Boolean + ): Boolean = + decodedClass match + case cls: DecodedClass.ClassDef => + if cls.symbol.isEnum then expectedParents == cls.symbol.parentClasses.toSet + defn.ProductClass + else if isInterface then expectedParents == cls.symbol.parentClasses.filter(_.isTrait).toSet + else if cls.symbol.isAnonClass then cls.symbol.parentClasses.forall(expectedParents.contains) + else expectedParents == cls.symbol.parentClasses.toSet + case _: DecodedClass.SyntheticCompanionClass => false + case anonFun: DecodedClass.SAMOrPartialFunction => + if anonFun.parentClass == Definitions.PartialFunctionClass then + expectedParents == Set(Definitions.AbstractPartialFunctionClass, Definitions.SerializableClass) + else expectedParents.contains(anonFun.parentClass) + case inlined: DecodedClass.InlinedClass => matchParents(inlined.underlying, expectedParents, isInterface) + + private def decodeClassRecursively(owner: ClassSymbol, decodedName: String): Seq[DecodedClass.ClassDef] = + owner.declarations + .collect { case sym: ClassSymbol => sym } + .flatMap { sym => + val Symbol = s"${Regex.quote(sym.sourceName)}\\$$?(.*)".r + decodedName match + case Symbol(remaining) => + if remaining.isEmpty then Some(DecodedClass.ClassDef(sym)) + else decodeClassRecursively(sym, remaining) + case _ => None + } + + protected def collectLiftedTrees[S](owner: Symbol, sourceLines: Option[binary.SourceLines])( + matcher: PartialFunction[LiftedTree[?], LiftedTree[S]] + ): Seq[LiftedTree[S]] = + val recursiveMatcher = new PartialFunction[LiftedTree[?], LiftedTree[S]]: + override def apply(tree: LiftedTree[?]): LiftedTree[S] = tree.asInstanceOf[LiftedTree[S]] + override def isDefinedAt(tree: LiftedTree[?]): Boolean = tree match + case InlinedFromArg(underlying, _, _) => isDefinedAt(underlying) + case InlinedFromDef(underlying, _) => isDefinedAt(underlying) + case _ => matcher.isDefinedAt(tree) + collectAllLiftedTrees(owner).collect(recursiveMatcher).filter(tree => sourceLines.forall(matchLines(tree, _))) + + protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = + LiftedTreeCollector.collect(owner) + + private def wrapIfInline(liftedTree: LiftedTree[?], decodedClass: DecodedClass): DecodedClass = + liftedTree match + case InlinedFromDef(underlying, inlineCall) => + DecodedClass.InlinedClass(wrapIfInline(underlying, decodedClass), inlineCall.callTree) + case _ => decodedClass + + private def matchLines(liftedFun: LiftedTree[?], sourceLines: binary.SourceLines): Boolean = + // we use endsWith instead of == because of tasty-query#434 + val positions = + liftedFun.scope(scoper).allPositions.filter(pos => pos.sourceFile.name.endsWith(sourceLines.sourceName)) + sourceLines.tastyLines.forall(line => positions.exists(_.containsLine(line))) diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala index 3b1eeae..1031eea 100644 --- a/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala @@ -2,20 +2,13 @@ package ch.epfl.scala.decoder import ch.epfl.scala.decoder.binary import ch.epfl.scala.decoder.internal.* -import ch.epfl.scala.decoder.javareflect.JavaReflectLoader import tastyquery.Contexts.* -import tastyquery.Names.* -import tastyquery.Signatures.* -import tastyquery.SourcePosition -import tastyquery.Symbols.* -import tastyquery.Trees.* -import tastyquery.Types.* +import tastyquery.Symbols.Symbol import tastyquery.jdk.ClasspathLoaders import java.nio.file.Path -import scala.util.matching.Regex -import tastyquery.Exceptions.NonMethodReferenceException -import tastyquery.SourceLanguage +import scala.collection.concurrent.TrieMap +import scala.util.Try object BinaryDecoder: def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder = @@ -26,1183 +19,25 @@ object BinaryDecoder: def cached(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder = val classpath = CustomClasspath(ClasspathLoaders.read(classEntries.toList)) val ctx = Context.initialize(classpath) - new CachedBinaryDecoder(using ctx) - -class BinaryDecoder(using Context, ThrowOrWarn): - private given defn: Definitions = Definitions() - - def decode(cls: binary.ClassType): DecodedClass = - val javaParts = cls.name.split('.') - val packageNames = javaParts.dropRight(1).toList.map(SimpleName.apply) - val packageSym = - if packageNames.nonEmpty - then ctx.findSymbolFromRoot(packageNames).asInstanceOf[PackageSymbol] - else defn.EmptyPackage - val decodedClassName = NameTransformer.decode(javaParts.last) - val allSymbols = decodedClassName match - case Patterns.AnonClass(declaringClassName, remaining) => - val WithLocalPart = "(.+)\\$(.+)\\$\\d+".r - val topLevelClassName = declaringClassName match - case WithLocalPart(topLevelClassName, _) => topLevelClassName.stripSuffix("$") - case topLevelClassName => topLevelClassName - reduceAmbiguityOnClasses(decodeLocalClasses(cls, packageSym, topLevelClassName, "$anon", remaining)) - case Patterns.LocalClass(declaringClassName, localClassName, remaining) => - decodeLocalClasses(cls, packageSym, declaringClassName, localClassName, remaining) - case _ => decodeClassFromPackage(packageSym, decodedClassName) - - val candidates = - if cls.isObject then allSymbols.filter(_.isModuleClass) - else if cls.sourceLines.forall(_.isEmpty) && allSymbols.forall(_.isModuleClass) then - allSymbols.collect { case cls: DecodedClass.ClassDef => DecodedClass.SyntheticCompanionClass(cls.symbol) } - else allSymbols.filter(!_.isModuleClass) - candidates.singleOrThrow(cls) - end decode - - def decode(method: binary.Method): DecodedMethod = - val decodedClass = decode(method.declaringClass) - decode(decodedClass, method) - - def decode(decodedClass: DecodedClass, method: binary.Method): DecodedMethod = - def tryDecode(f: PartialFunction[binary.Method, Seq[DecodedMethod]]): Seq[DecodedMethod] = - f.applyOrElse(method, _ => Seq.empty[DecodedMethod]) - - extension (xs: Seq[DecodedMethod]) - def orTryDecode(f: PartialFunction[binary.Method, Seq[DecodedMethod]]): Seq[DecodedMethod] = - if xs.nonEmpty then xs else f.applyOrElse(method, _ => Seq.empty[DecodedMethod]) - val candidates = - tryDecode { - // static and/or bridge - case Patterns.AdaptedAnonFun() => decodeAdaptedAnonFun(decodedClass, method) - // bridge or standard - case Patterns.SpecializedMethod(names) => decodeSpecializedMethod(decodedClass, method, names) - // bridge only - case m if m.isBridge => decodeBridgesAndMixinForwarders(decodedClass, method).toSeq - // static or standard - case Patterns.AnonFun() => decodeAnonFunsAndReduceAmbiguity(decodedClass, method) - case Patterns.ByNameArgProxy() => decodeByNameArgsProxy(decodedClass, method) - case Patterns.SuperArg() => decodeSuperArgs(decodedClass, method) - case Patterns.LiftedTree() => decodeLiftedTries(decodedClass, method) - case Patterns.LocalLazyInit(names) => decodeLocalLazyInit(decodedClass, method, names) - // static only - case Patterns.TraitInitializer() => decodeTraitInitializer(decodedClass, method) - case Patterns.DeserializeLambda() => - Seq(DecodedMethod.DeserializeLambda(decodedClass, defn.DeserializeLambdaType)) - case Patterns.TraitStaticForwarder() => decodeTraitStaticForwarder(decodedClass, method).toSeq - case m if m.isStatic && decodedClass.isJava => decodeStaticJavaMethods(decodedClass, method) - // cannot be static anymore - case Patterns.LazyInit(name) => decodeLazyInit(decodedClass, name) - case Patterns.Outer() => decodeOuter(decodedClass).toSeq - case Patterns.ParamForwarder(names) => decodeParamForwarder(decodedClass, method, names) - case Patterns.TraitSetter(name) => decodeTraitSetter(decodedClass, method, name) - case Patterns.Setter(names) => - decodeStandardMethods(decodedClass, method).orIfEmpty(decodeSetter(decodedClass, method, names)) - case Patterns.SuperAccessor(names) => decodeSuperAccessor(decodedClass, method, names) - } - .orTryDecode { case Patterns.ValueClassExtension() => decodeValueClassExtension(decodedClass, method) } - .orTryDecode { case Patterns.InlineAccessor(names) => decodeInlineAccessor(decodedClass, method, names).toSeq } - .orTryDecode { case Patterns.LocalMethod(names) => decodeLocalMethods(decodedClass, method, names) } - .orTryDecode { - case m if m.isStatic => decodeStaticForwarder(decodedClass, method) - case _ => decodeStandardMethods(decodedClass, method) - } - - candidates.singleOrThrow(method) - end decode - - def decode(field: binary.Field): DecodedField = - val decodedClass = decode(field.declaringClass) - decode(decodedClass, field) - - def decode(decodedClass: DecodedClass, field: binary.Field): DecodedField = - def tryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] = - f.applyOrElse(field, _ => Seq.empty[DecodedField]) - - extension (xs: Seq[DecodedField]) - def orTryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] = - if xs.nonEmpty then xs else f.applyOrElse(field, _ => Seq.empty[DecodedField]) - val decodedFields = - tryDecode { - case Patterns.LazyVal(name) => - for - owner <- decodedClass.classSymbol.toSeq ++ decodedClass.linearization.filter(_.isTrait) - sym <- owner.declarations.collect { - case sym: TermSymbol if sym.nameStr == name && sym.isModuleOrLazyVal => sym - } - yield DecodedField.ValDef(decodedClass, sym) - case Patterns.Module() => - decodedClass.classSymbol.flatMap(_.moduleValue).map(DecodedField.ModuleVal(decodedClass, _)).toSeq - case Patterns.Offset(nbr) => - Seq(DecodedField.LazyValOffset(decodedClass, nbr, defn.LongType)) - case Patterns.OuterField() => - decodedClass.symbolOpt - .flatMap(_.outerClass) - .map(outerClass => DecodedField.Outer(decodedClass, outerClass.selfType)) - .toSeq - case Patterns.SerialVersionUID() => - Seq(DecodedField.SerialVersionUID(decodedClass, defn.LongType)) - case Patterns.LazyValBitmap(name) => - Seq(DecodedField.LazyValBitmap(decodedClass, defn.BooleanType, name)) - case Patterns.AnyValCapture() => - for - classSym <- decodedClass.symbolOpt.toSeq - outerClass <- classSym.outerClass.toSeq - if outerClass.isSubClass(defn.AnyValClass) - sym <- outerClass.declarations.collect { - case sym: TermSymbol if sym.isVal && !sym.isMethod => sym - } - yield DecodedField.Capture(decodedClass, sym) - case Patterns.Capture(names) => - decodedClass.treeOpt.toSeq - .flatMap(CaptureCollector.collectCaptures) - .filter { captureSym => - names.exists { - case Patterns.LazyVal(name) => name == captureSym.nameStr - case name => name == captureSym.nameStr - } - } - .map(DecodedField.Capture(decodedClass, _)) - - case _ if field.isStatic && decodedClass.isJava => - for - owner <- decodedClass.companionClassSymbol.toSeq - sym <- owner.declarations.collect { case sym: TermSymbol if sym.nameStr == field.name => sym } - yield DecodedField.ValDef(decodedClass, sym) - }.orTryDecode { case _ => - for - owner <- withCompanionIfExtendsJavaLangEnum(decodedClass) ++ decodedClass.linearization.filter(_.isTrait) - sym <- owner.declarations.collect { - case sym: TermSymbol if matchTargetName(field, sym) && !sym.isMethod => sym - } - yield DecodedField.ValDef(decodedClass, sym) - } - decodedFields.singleOrThrow(field) - end decode - - def decode(variable: binary.Variable, sourceLine: Int): DecodedVariable = - val decodedMethod = decode(variable.declaringMethod) - decode(decodedMethod, variable, sourceLine) - - def decode(decodedMethod: DecodedMethod, variable: binary.Variable, sourceLine: Int): DecodedVariable = - def tryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = - f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) - - extension (xs: Seq[DecodedVariable]) - def orTryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = - if xs.nonEmpty then xs else f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) - val decodedVariables = tryDecode { - case Patterns.CapturedLzyVariable(name) => decodeCapturedLzyVariable(decodedMethod, name) - case Patterns.CapturedTailLocalVariable(name) => decodeCapturedVariable(decodedMethod, name) - case Patterns.CapturedVariable(name) => decodeCapturedVariable(decodedMethod, name) - case Patterns.This() => decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) - case Patterns.DollarThis() => decodeDollarThis(decodedMethod) - case Patterns.Proxy(name) => decodeProxy(decodedMethod, name) - case Patterns.InlinedThis() => decodeInlinedThis(decodedMethod, variable) - }.orTryDecode { case _ => - decodedMethod match - case decodedMethod: DecodedMethod.SAMOrPartialFunctionImpl => - decodeValDef(decodedMethod, variable, sourceLine) - .orIfEmpty( - decodeSAMOrPartialFun( - decodedMethod, - decodedMethod.implementedSymbol, - variable, - sourceLine - ) - ) - case _: DecodedMethod.Bridge => ignore(variable, "Bridge method") - case _ => decodeValDef(decodedMethod, variable, sourceLine) - } - decodedVariables.singleOrThrow(variable, decodedMethod) - - private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] = - if syms.size > 1 then - val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _))) - if reduced.size != 0 then reduced else syms - else syms - - private def enclose(enclosing: DecodedClass, enclosed: DecodedClass): Boolean = - (enclosing, enclosed) match - case (enclosing: DecodedClass.InlinedClass, enclosed: DecodedClass.InlinedClass) => - enclosing.callPos.enclose(enclosed.callPos) || ( - !enclosed.callPos.enclose(enclosing.callPos) && - enclose(enclosing.underlying, enclosed.underlying) - ) - case (enclosing: DecodedClass.InlinedClass, enclosed) => - enclosing.callPos.enclose(enclosed.pos) - case (enclosing, enclosed: DecodedClass.InlinedClass) => - enclosing.pos.enclose(enclosed.callPos) - case (enclosing, enclosed) => - enclosing.pos.enclose(enclosed.pos) - - private def decodeLocalClasses( - javaClass: binary.ClassType, - packageSym: PackageSymbol, - declaringClassName: String, - localClassName: String, - remaining: Option[String] - ): Seq[DecodedClass] = - val classOwners = decodeClassFromPackage(packageSym, declaringClassName).map(_.symbol) - remaining match - case None => - val parents = (javaClass.superclass.toSet ++ javaClass.interfaces) - .map(decode) - .collect { case cls: DecodedClass.ClassDef => cls.symbol } - classOwners - .flatMap(cls => collectLocalClasses(cls, localClassName, javaClass.sourceLines)) - .filter(matchParents(_, parents, javaClass.isInterface)) - case Some(remaining) => - val localClasses = classOwners - .flatMap(cls => collectLocalClasses(cls, localClassName, None)) - .flatMap(_.classSymbol) - localClasses.flatMap(s => decodeClassRecursively(s, remaining)) - - private def decodeClassFromPackage(owner: PackageSymbol, decodedName: String): Seq[DecodedClass.ClassDef] = - val packageObject = "([^\\$]+\\$package)(\\$.*)?".r - val specializedClass = "([^\\$]+\\$mc.+\\$sp)(\\$.*)?".r - val standardClass = "([^\\$]+)(\\$.*)?".r - val topLevelName = decodedName match - case packageObject(name, _) => name - case specializedClass(name, _) => name - case standardClass(name, _) => name - val remaining = decodedName.stripPrefix(topLevelName).stripPrefix("$") - val typeNames = Seq(typeName(topLevelName), moduleClassName(topLevelName)) - typeNames - .flatMap(owner.getDecl) - .collect { case sym: ClassSymbol => sym } - .flatMap { sym => - if remaining.isEmpty then Seq(DecodedClass.ClassDef(sym)) - else decodeClassRecursively(sym, remaining) - } - - private def decodeClassRecursively(owner: ClassSymbol, decodedName: String): Seq[DecodedClass.ClassDef] = - owner.declarations - .collect { case sym: ClassSymbol => sym } - .flatMap { sym => - val Symbol = s"${Regex.quote(sym.sourceName)}\\$$?(.*)".r - decodedName match - case Symbol(remaining) => - if remaining.isEmpty then Some(DecodedClass.ClassDef(sym)) - else decodeClassRecursively(sym, remaining) - case _ => None - } - - private def collectLocalClasses( - classSymbol: ClassSymbol, - name: String, - sourceLines: Option[binary.SourceLines] - ): Seq[DecodedClass] = - val localClasses = collectLiftedTrees(classSymbol, sourceLines) { - case cls: LocalClass if cls.symbol.sourceName == name => cls - } - .map(cls => wrapIfInline(cls, DecodedClass.ClassDef(cls.symbol))) - val samAndPartialFunctions = collectLiftedTrees(classSymbol, sourceLines) { case lambda: LambdaTree => lambda } - .map { lambda => - val (term, samClass) = lambda.symbol - wrapIfInline(lambda, DecodedClass.SAMOrPartialFunction(term, samClass, lambda.tpe.asInstanceOf)) - } - localClasses ++ samAndPartialFunctions - - private def matchParents( - decodedClass: DecodedClass, - expectedParents: Set[ClassSymbol], - isInterface: Boolean - ): Boolean = - decodedClass match - case cls: DecodedClass.ClassDef => - if cls.symbol.isEnum then expectedParents == cls.symbol.parentClasses.toSet + defn.ProductClass - else if isInterface then expectedParents == cls.symbol.parentClasses.filter(_.isTrait).toSet - else if cls.symbol.isAnonClass then cls.symbol.parentClasses.forall(expectedParents.contains) - else expectedParents == cls.symbol.parentClasses.toSet - case _: DecodedClass.SyntheticCompanionClass => false - case anonFun: DecodedClass.SAMOrPartialFunction => - if anonFun.parentClass == defn.PartialFunctionClass then - expectedParents == Set(defn.AbstractPartialFunctionClass, defn.SerializableClass) - else expectedParents.contains(anonFun.parentClass) - case inlined: DecodedClass.InlinedClass => matchParents(inlined.underlying, expectedParents, isInterface) - - private def matchParents(classSymbol: ClassSymbol, expectedParents: Set[ClassSymbol], isInterface: Boolean): Boolean = - if classSymbol.isEnum then expectedParents == classSymbol.parentClasses.toSet + defn.ProductClass - else if isInterface then expectedParents == classSymbol.parentClasses.filter(_.isTrait).toSet - else if classSymbol.isAnonClass then classSymbol.parentClasses.forall(expectedParents.contains) - else expectedParents == classSymbol.parentClasses.toSet - - private def wrapIfInline(liftedTree: LiftedTree[?], decodedClass: DecodedClass): DecodedClass = - liftedTree match - case InlinedFromDef(underlying, inlineCall) => - DecodedClass.InlinedClass(wrapIfInline(underlying, decodedClass), inlineCall.callTree) - case _ => decodedClass - - private def decodeStaticJavaMethods(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - decodedClass.companionClassSymbol.toSeq - .flatMap(_.declarations) - .collect { - case sym: TermSymbol - if matchTargetName(method, sym) && matchSignature(method, sym.declaredType, checkParamNames = false) => - DecodedMethod.ValOrDefDef(decodedClass, sym) - } - - private def decodeStandardMethods(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - def rec(underlying: DecodedClass): Seq[DecodedMethod] = - underlying match - case anonFun: DecodedClass.SAMOrPartialFunction => - if method.isConstructor then Seq(DecodedMethod.SAMOrPartialFunctionConstructor(decodedClass, anonFun.tpe)) - else if anonFun.parentClass == defn.PartialFunctionClass then - decodePartialFunctionImpl(decodedClass, anonFun.tpe, method).toSeq - else decodeSAMFunctionImpl(decodedClass, anonFun.symbol, anonFun.parentClass, method).toSeq - case underlying: DecodedClass.ClassDef => decodeInstanceMethods(decodedClass, underlying.symbol, method) - case _: DecodedClass.SyntheticCompanionClass => Seq.empty - case inlined: DecodedClass.InlinedClass => rec(inlined.underlying) - rec(decodedClass) - - private def decodeParamForwarder( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod.ValOrDefDef] = - decodedClass.declarations.collect { - case sym: TermSymbol if names.contains(sym.targetNameStr) && matchSignature(method, sym.declaredType) => - DecodedMethod.ValOrDefDef(decodedClass, sym) - } - - private def decodeTraitSetter( - decodedClass: DecodedClass, - method: binary.Method, - name: String - ): Seq[DecodedMethod.SetterAccessor] = - for - traitSym <- decodedClass.linearization.filter(_.isTrait) - if method.decodedName.contains("$" + traitSym.nameStr + "$") - sym <- traitSym.declarations.collect { - case sym: TermSymbol if sym.targetNameStr == name && !sym.isMethod && !sym.isAbstractMember => sym - } - paramType <- decodedClass.thisType.map(sym.typeAsSeenFrom).collect { case tpe: Type => tpe } - yield - val tpe = MethodType(List(SimpleName("x$1")), List(paramType), defn.UnitType) - DecodedMethod.SetterAccessor(decodedClass, sym, tpe) - - private def decodeSetter( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod.SetterAccessor] = - for - param <- method.allParameters.lastOption.toSeq - sym <- decodeFields(decodedClass, param.`type`, names) - yield - val tpe = MethodType(List(SimpleName("x$1")), List(sym.declaredType.asInstanceOf[Type]), defn.UnitType) - DecodedMethod.SetterAccessor(decodedClass, sym, tpe) - - private def decodeFields( - decodedClass: DecodedClass, - binaryType: binary.Type, - names: Seq[String] - ): Seq[TermSymbol] = - def matchType0(sym: TermSymbol): Boolean = matchSetterArgType(sym.declaredType, binaryType) - decodedClass.declarations.collect { - case sym: TermSymbol if !sym.isMethod && names.exists(sym.targetNameStr == _) && matchType0(sym) => - sym - } - - private def decodeSuperAccessor( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod] = - for - traitSym <- decodedClass.linearization.filter(_.isTrait) - if method.decodedName.contains("$" + traitSym.nameStr + "$") - sym <- traitSym.declarations.collect { - case sym: TermSymbol if names.contains(sym.targetNameStr) && !sym.isAbstractMember => sym - } - expectedTpe <- decodedClass.thisType.map(sym.typeAsSeenFrom(_)) - if matchSignature(method, expectedTpe) - yield DecodedMethod.SuperAccessor(decodedClass, sym, expectedTpe) - - private def decodeSpecializedMethod( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod.SpecializedMethod] = - decodedClass.declarations.collect { - case sym: TermSymbol - if names.contains(sym.targetNameStr) && - matchSignature( - method, - sym.declaredType, - captureAllowed = false, - checkParamNames = false, - checkTypeErasure = false - ) && - // hack: in Scala 3 only overriding symbols can be specialized (Function and Tuple) - sym.allOverriddenSymbols.nonEmpty => - DecodedMethod.SpecializedMethod(decodedClass, sym) - } - - private def decodeInlineAccessor( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod] = - val classLoader = method.declaringClass.classLoader - val methodAccessors = method.instructions - .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => - classLoader.loadClass(owner).method(name, descriptor) - } - .singleOpt - .flatten - .map { binaryTarget => - val target = decode(binaryTarget) - // val tpe = target.declaredType.asSeenFrom(fromType, fromClass) - DecodedMethod.InlineAccessor(decodedClass, target) - } - def singleFieldInstruction(f: binary.Instruction.Field => Boolean) = method.instructions - .collect { case instr: binary.Instruction.Field => instr } - .singleOpt - .filter(f) - .toSeq - def fieldSetters = - val expectedNames = names.map(_.stripSuffix("_=")).distinct - for - instr <- singleFieldInstruction(f => f.isPut && f.unexpandedDecodedNames.exists(expectedNames.contains)) - binaryField <- classLoader.loadClass(instr.owner).declaredField(instr.name).toSeq - fieldOwner = decode(binaryField.declaringClass) - sym <- decodeFields(fieldOwner, binaryField.`type`, instr.unexpandedDecodedNames) - yield - val tpe = MethodType(List(SimpleName("x$1")), List(sym.declaredType.asInstanceOf[Type]), defn.UnitType) - val decodedTarget = DecodedMethod.SetterAccessor(fieldOwner, sym, tpe) - DecodedMethod.InlineAccessor(decodedClass, decodedTarget) - def fieldGetters = - for - instr <- singleFieldInstruction(f => !f.isPut && f.unexpandedDecodedNames.exists(names.contains)) - binaryField <- classLoader.loadClass(instr.owner).declaredField(instr.name).toSeq - fieldOwner = decode(binaryField.declaringClass) - sym <- decodeFields(fieldOwner, binaryField.`type`, instr.unexpandedDecodedNames) - yield DecodedMethod.InlineAccessor(decodedClass, DecodedMethod.ValOrDefDef(fieldOwner, sym)) - def moduleAccessors = - for - instr <- singleFieldInstruction(_.name == "MODULE$") - targetClass = decode(classLoader.loadClass(instr.owner)) - targetClassSym <- targetClass.classSymbol - targetTermSym <- targetClassSym.moduleValue - yield DecodedMethod.InlineAccessor(decodedClass, DecodedMethod.ValOrDefDef(targetClass, targetTermSym)) - def valueClassAccessors = - if method.instructions.isEmpty && method.isExtensionMethod then - for - companionClass <- decodedClass.companionClass.toSeq - param <- method.allParameters.lastOption.toSeq - sym <- decodeFields(companionClass, param.`type`, names.map(_.stripSuffix("$extension"))) - yield - val decodedTarget = DecodedMethod.ValOrDefDef(decodedClass, sym) - DecodedMethod.InlineAccessor(decodedClass, decodedTarget) - else Seq.empty - methodAccessors.toSeq - .orIfEmpty(fieldSetters) - .orIfEmpty(fieldGetters) - .orIfEmpty(moduleAccessors.toSeq) - .orIfEmpty(valueClassAccessors) - - private def decodeInstanceMethods( - decodedClass: DecodedClass, - classSymbol: ClassSymbol, - method: binary.Method - ): Seq[DecodedMethod] = - if method.isConstructor && classSymbol.isSubClass(defn.AnyValClass) then - classSymbol.getAllOverloadedDecls(SimpleName("")).map(DecodedMethod.ValOrDefDef(decodedClass, _)) - else - val isJava = decodedClass.isJava - val fromClass = classSymbol.declarations - .collect { case sym: TermSymbol if matchTargetName(method, sym) => sym } - .collect { - case sym - if matchSignature( - method, - sym.declaredType, - asJavaVarargs = isJava, - captureAllowed = !isJava, - checkParamNames = !isJava - ) => - DecodedMethod.ValOrDefDef(decodedClass, sym) - case sym if !isJava && matchSignature(method, sym.declaredType, asJavaVarargs = true) => - DecodedMethod.Bridge(DecodedMethod.ValOrDefDef(decodedClass, sym), sym.declaredType) - } - fromClass.orIfEmpty(decodeAccessorsFromTraits(decodedClass, classSymbol, method)) - - private def decodeAccessorsFromTraits( - decodedClass: DecodedClass, - classSymbol: ClassSymbol, - method: binary.Method - ): Seq[DecodedMethod] = - if classSymbol.isTrait then Seq.empty - else decodeAccessorsFromTraits(decodedClass, classSymbol, classSymbol.thisType, method) - - private def decodeAccessorsFromTraits( - decodedClass: DecodedClass, - fromClass: ClassSymbol, - fromType: Type, - method: binary.Method - ): Seq[DecodedMethod] = - for - traitSym <- fromClass.linearization.filter(_.isTrait) - if !method.isExpanded || method.decodedName.contains("$" + traitSym.nameStr + "$") - sym <- traitSym.declarations - .collect { - case sym: TermSymbol if matchTargetName(method, sym) && matchSignature(method, sym.declaredType) => sym - } - if method.isExpanded == sym.isPrivate - if sym.isParamAccessor || sym.isSetter || !sym.isMethod - if sym.isOverridingSymbol(fromClass) - yield - val tpe = sym.typeAsSeenFrom(fromType) - if sym.isParamAccessor then DecodedMethod.TraitParamAccessor(decodedClass, sym) - else if sym.isSetter then DecodedMethod.SetterAccessor(decodedClass, sym, tpe) - else DecodedMethod.GetterAccessor(decodedClass, sym, tpe) - - private def decodeLazyInit(decodedClass: DecodedClass, name: String): Seq[DecodedMethod] = - val matcher: PartialFunction[Symbol, TermSymbol] = - case sym: TermSymbol if sym.isModuleOrLazyVal && sym.nameStr == name => sym - val fromClass = decodedClass.declarations.collect(matcher).map(DecodedMethod.LazyInit(decodedClass, _)) - def fromTraits = - for - traitSym <- decodedClass.linearization.filter(_.isTrait) - term <- traitSym.declarations.collect(matcher) - if term.isOverridingSymbol(decodedClass) - yield DecodedMethod.LazyInit(decodedClass, term) - fromClass.orIfEmpty(fromTraits) - - private def decodeTraitStaticForwarder( - decodedClass: DecodedClass, - method: binary.Method - ): Option[DecodedMethod.TraitStaticForwarder] = - method.instructions - .collect { - case binary.Instruction.Method(_, owner, name, descriptor, _) if owner == method.declaringClass.name => - method.declaringClass.method(name, descriptor) - } - .singleOpt - .flatten - .map(target => DecodedMethod.TraitStaticForwarder(decode(decodedClass, target))) - - private def decodeOuter(decodedClass: DecodedClass): Option[DecodedMethod.OuterAccessor] = - decodedClass.symbolOpt - .flatMap(_.outerClass) - .map(outerClass => DecodedMethod.OuterAccessor(decodedClass, outerClass.thisType)) - - private def decodeTraitInitializer( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod.ValOrDefDef] = - decodedClass.declarations.collect { - case sym: TermSymbol if sym.name == nme.Constructor => DecodedMethod.ValOrDefDef(decodedClass, sym) - } - - private def decodeValueClassExtension( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod.ValOrDefDef] = - val names = method.unexpandedDecodedNames.map(_.stripSuffix("$extension")) - decodedClass.companionClassSymbol.toSeq.flatMap(_.declarations).collect { - case sym: TermSymbol if names.contains(sym.targetNameStr) && matchSignature(method, sym.declaredType) => - DecodedMethod.ValOrDefDef(decodedClass, sym) - } - - private def decodeStaticForwarder( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod.StaticForwarder] = - decodedClass.companionClassSymbol.toSeq.flatMap(decodeStaticForwarder(decodedClass, _, method)) - - private def decodeStaticForwarder( - decodedClass: DecodedClass, - companionObject: ClassSymbol, - method: binary.Method - ): Seq[DecodedMethod.StaticForwarder] = - method.instructions - .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => - method.declaringClass.classLoader.loadClass(owner).method(name, descriptor) - } - .flatten - .singleOpt - .toSeq - .map(decode) - .collect { - case mixin: DecodedMethod.MixinForwarder => mixin.target - case target => target - } - .map { target => - val declaredType = target.symbolOpt - .map(_.typeAsSeenFrom(companionObject.thisType)) - .getOrElse(target.declaredType) - DecodedMethod.StaticForwarder(decodedClass, target, declaredType) - } - - private def decodeSAMFunctionImpl( - decodedClass: DecodedClass, - symbol: TermSymbol, - parentClass: ClassSymbol, - method: binary.Method - ): Option[DecodedMethod] = - val types = - for - parentCls <- parentClass.linearization.iterator - overridden <- parentCls.declarations.collect { case term: TermSymbol if matchTargetName(method, term) => term } - if overridden.overridingSymbol(parentClass).exists(_.isAbstractMember) - yield DecodedMethod.SAMOrPartialFunctionImpl(decodedClass, overridden, symbol.declaredType) - types.nextOption - - private def decodePartialFunctionImpl( - decodedClass: DecodedClass, - tpe: Type, - method: binary.Method - ): Option[DecodedMethod] = - for sym <- defn.PartialFunctionClass.getNonOverloadedDecl(SimpleName(method.name)) yield - val implTpe = sym.typeAsSeenFrom(SkolemType(tpe)) - DecodedMethod.SAMOrPartialFunctionImpl(decodedClass, sym, implTpe) - - private def decodeBridgesAndMixinForwarders( - decodedClass: DecodedClass, - method: binary.Method - ): Option[DecodedMethod] = - def rec(underlying: DecodedClass): Option[DecodedMethod] = - underlying match - case underlying: DecodedClass.ClassDef => - if !underlying.symbol.isTrait then - decodeBridgesAndMixinForwarders(decodedClass, underlying.symbol, underlying.symbol.thisType, method) - else None - case underlying: DecodedClass.SAMOrPartialFunction => - decodeBridgesAndMixinForwarders(decodedClass, underlying.parentClass, SkolemType(underlying.tpe), method) - case underlying: DecodedClass.InlinedClass => rec(underlying.underlying) - case _: DecodedClass.SyntheticCompanionClass => None - rec(decodedClass) - - private def decodeBridgesAndMixinForwarders( - decodedClass: DecodedClass, - fromClass: ClassSymbol, - fromType: Type, - method: binary.Method - ): Option[DecodedMethod] = - decodeBridges(decodedClass, fromClass, fromType, method) - .orIfEmpty(decodeMixinForwarder(decodedClass, method)) - - private def decodeBridges( - decodedClass: DecodedClass, - fromClass: ClassSymbol, - fromType: Type, - method: binary.Method - ): Option[DecodedMethod] = - method.instructions - .collect { - case binary.Instruction.Method(_, owner, name, descriptor, _) if name == method.name => - method.declaringClass.classLoader.loadClass(owner).method(name, descriptor) - } - .singleOpt - .flatten - .map { binaryTarget => - val target = decode(binaryTarget) - val tpe = target.declaredType.asSeenFrom(fromType, fromClass) - DecodedMethod.Bridge(target, tpe) - } - - private def decodeMixinForwarder( - decodedClass: DecodedClass, - method: binary.Method - ): Option[DecodedMethod.MixinForwarder] = - method.instructions - .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => - method.declaringClass.classLoader.loadClass(owner).method(name, descriptor) - } - .singleOpt - .flatten - .filter(target => target.isStatic && target.declaringClass.isInterface) - .map(decode) - .collect { case staticForwarder: DecodedMethod.TraitStaticForwarder => - DecodedMethod.MixinForwarder(decodedClass, staticForwarder.target) - } - - private def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match - case classDef: DecodedClass.ClassDef => - Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass)) - case _: DecodedClass.SyntheticCompanionClass => Seq.empty - case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol) - case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying) - - private def withCompanionIfExtendsJavaLangEnum(decodedClass: DecodedClass): Seq[ClassSymbol] = - decodedClass.classSymbol.toSeq.flatMap { cls => - if cls.isSubClass(defn.javaLangEnumClass) then Seq(cls) ++ cls.companionClass - else Seq(cls) - } - - private def decodeAdaptedAnonFun(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - if method.instructions.nonEmpty then - val underlying = method.instructions - .collect { - case binary.Instruction.Method(_, owner, name, descriptor, _) if owner == method.declaringClass.name => - method.declaringClass.declaredMethod(name, descriptor) - } - .flatten - .singleOrElse(unexpected(s"$method is not an adapted method: cannot find underlying invocation")) - decodeAnonFunsAndByNameArgs(decodedClass, underlying).map(DecodedMethod.AdaptedFun(_)) - else Seq.empty - - private def decodeAnonFunsAndReduceAmbiguity( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod] = - val candidates = decodeAnonFunsAndByNameArgs(decodedClass, method) - if candidates.size > 1 then - val clashingMethods = method.declaringClass.declaredMethods - .filter(m => m.returnType.zip(method.returnType).forall(_ == _) && m.signedName.name != method.signedName.name) - .collect { case m @ Patterns.AnonFun() if m.name != method.name => m } - .map(m => m -> decodeAnonFunsAndByNameArgs(decodedClass, m).toSet) - .toMap - def reduceAmbiguity( - methods: Map[binary.Method, Set[DecodedMethod]] - ): Map[binary.Method, Set[DecodedMethod]] = - val found = methods.collect { case (m, syms) if syms.size == 1 => syms.head } - val reduced = methods.map { case (m, candidates) => - if candidates.size > 1 then m -> (candidates -- found) - else m -> candidates - } - if reduced.count { case (_, s) => s.size == 1 } == found.size then methods - else reduceAmbiguity(reduced) - reduceAmbiguity(clashingMethods + (method -> candidates.toSet))(method).toSeq - else candidates - - private def decodeAnonFunsAndByNameArgs( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod] = - val anonFuns = decodeLocalMethods(decodedClass, method, Seq(CommonNames.anonFun.toString)) - val byNameArgs = - if method.allParameters.forall(_.isCapture) then decodeByNameArgs(decodedClass, method) - else Seq.empty - reduceAmbiguityOnMethods(anonFuns ++ byNameArgs) - - private def decodeLocalMethods( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod] = - collectLocalMethods(decodedClass, method) { - case fun if names.contains(fun.symbol.name.toString) && matchLiftedFunSignature(method, fun) => - wrapIfInline(fun, DecodedMethod.ValOrDefDef(decodedClass, fun.symbol.asTerm)) - } - - private def reduceAmbiguityOnMethods(syms: Seq[DecodedMethod]): Seq[DecodedMethod] = - if syms.size > 1 then - val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _))) - if reduced.size != 0 then reduced else syms - else syms - - private def enclose(enclosing: DecodedMethod, enclosed: DecodedMethod): Boolean = - (enclosing, enclosed) match - case (enclosing: DecodedMethod.InlinedMethod, enclosed: DecodedMethod.InlinedMethod) => - enclosing.callPos.enclose(enclosed.callPos) || ( - !enclosed.callPos.enclose(enclosing.callPos) && - enclose(enclosing.underlying, enclosed.underlying) - ) - case (enclosing: DecodedMethod.InlinedMethod, enclosed) => - enclosing.callPos.enclose(enclosed.pos) - case (enclosing, enclosed: DecodedMethod.InlinedMethod) => - enclosing.pos.enclose(enclosed.callPos) - case (enclosing, enclosed) => - enclosing.pos.enclose(enclosed.pos) - - private def decodeByNameArgs(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - collectLiftedTrees(decodedClass, method) { case arg: ByNameArg if !arg.isInline => arg } - .collect { - case arg if matchReturnType(arg.tpe, method.returnType) && matchCapture(arg.capture, method.allParameters) => - wrapIfInline(arg, DecodedMethod.ByNameArg(decodedClass, arg.owner, arg.tree, arg.tpe.asInstanceOf)) - } - - private def decodeByNameArgsProxy(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - val explicitByNameArgs = - collectLiftedTrees(decodedClass, method) { case arg: ByNameArg if arg.isInline => arg } - .collect { - case arg if matchReturnType(arg.tpe, method.returnType) && matchCapture(arg.capture, method.allParameters) => - wrapIfInline(arg, DecodedMethod.ByNameArg(decodedClass, arg.owner, arg.tree, arg.tpe.asInstanceOf)) - } - val inlineOverrides = - for - classSym <- decodedClass.classSymbol.toSeq - sym <- classSym.declarations.collect { - case sym: TermSymbol if sym.allOverriddenSymbols.nonEmpty && sym.isInline => sym - } - if method.sourceLines.forall(sym.pos.matchLines) - paramSym <- sym.paramSymbols - resultType <- Seq(paramSym.declaredType).collect { case tpe: ByNameType => tpe.resultType } - if matchReturnType(resultType, method.returnType) - yield - val argTree = Ident(paramSym.name)(paramSym.localRef)(SourcePosition.NoPosition) - DecodedMethod.ByNameArg(decodedClass, sym, argTree, resultType) - explicitByNameArgs ++ inlineOverrides - - private def collectLocalMethods( - decodedClass: DecodedClass, - method: binary.Method - )( - matcher: PartialFunction[LiftedTree[TermSymbol], DecodedMethod] - ): Seq[DecodedMethod] = - collectLiftedTrees(decodedClass, method) { case term: LocalTermDef => term } - .collect(matcher) - - private def decodeSuperArgs( - decodedClass: DecodedClass, - method: binary.Method - ): Seq[DecodedMethod.SuperConstructorArg] = - def matchSuperArg(liftedArg: LiftedTree[Nothing]): Boolean = - val primaryConstructor = liftedArg.owner.asClass.getAllOverloadedDecls(nme.Constructor).head - // a super arg takes the same parameters as its constructor - val sourceParams = extractSourceParams(method, primaryConstructor.declaredType) - val binaryParams = splitBinaryParams(method, sourceParams) - matchReturnType(liftedArg.tpe, method.returnType) && matchCapture(liftedArg.capture, binaryParams.capturedParams) - collectLiftedTrees(decodedClass, method) { case arg: ConstructorArg => arg } - .collect { - case liftedArg if matchSuperArg(liftedArg) => - DecodedMethod.SuperConstructorArg( - decodedClass, - liftedArg.owner.asClass, - liftedArg.tree, - liftedArg.tpe.asInstanceOf - ) - } - - private def decodeLiftedTries(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = - collectLiftedTrees(decodedClass, method) { case tree: LiftedTry => tree } - .collect { - case liftedTry if matchReturnType(liftedTry.tpe, method.returnType) => - wrapIfInline( - liftedTry, - DecodedMethod.LiftedTry(decodedClass, liftedTry.owner, liftedTry.tree, liftedTry.tpe.asInstanceOf) - ) - } - - private def decodeLocalLazyInit( - decodedClass: DecodedClass, - method: binary.Method, - names: Seq[String] - ): Seq[DecodedMethod] = - collectLocalMethods(decodedClass, method) { - case term if term.symbol.isModuleOrLazyVal && names.contains(term.symbol.nameStr) => - wrapIfInline(term, DecodedMethod.LazyInit(decodedClass, term.symbol)) - } - - private def wrapIfInline(liftedTree: LiftedTree[?], decodedMethod: DecodedMethod): DecodedMethod = - liftedTree match - case InlinedFromDef(liftedTree, inlineCall) => - DecodedMethod.InlinedMethod(wrapIfInline(liftedTree, decodedMethod), inlineCall.callTree) - case _ => decodedMethod - - private def collectLiftedTrees[S](decodedClass: DecodedClass, method: binary.Method)( - matcher: PartialFunction[LiftedTree[?], LiftedTree[S]] - ): Seq[LiftedTree[S]] = - val owners = withCompanionIfExtendsAnyVal(decodedClass) - val sourceLines = - if owners.size == 2 && method.allParameters.exists(p => p.name.matches("\\$this\\$\\d+")) then - // workaround of https://github.com/lampepfl/dotty/issues/18816 - method.sourceLines.map(_.last) - else method.sourceLines - owners.flatMap(collectLiftedTrees(_, sourceLines)(matcher)) - - private def collectLiftedTrees[S](owner: Symbol, sourceLines: Option[binary.SourceLines])( - matcher: PartialFunction[LiftedTree[?], LiftedTree[S]] - ): Seq[LiftedTree[S]] = - val recursiveMatcher = new PartialFunction[LiftedTree[?], LiftedTree[S]]: - override def apply(tree: LiftedTree[?]): LiftedTree[S] = tree.asInstanceOf[LiftedTree[S]] - override def isDefinedAt(tree: LiftedTree[?]): Boolean = tree match - case InlinedFromArg(underlying, _, _) => isDefinedAt(underlying) - case InlinedFromDef(underlying, _) => isDefinedAt(underlying) - case _ => matcher.isDefinedAt(tree) - collectAllLiftedTrees(owner).collect(recursiveMatcher).filter(tree => sourceLines.forall(matchLines(tree, _))) - - protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = - LiftedTreeCollector.collect(owner) - - private def matchLines(liftedFun: LiftedTree[?], sourceLines: binary.SourceLines): Boolean = - // we use endsWith instead of == because of tasty-query#434 - val positions = liftedFun.positions.filter(pos => pos.sourceFile.name.endsWith(sourceLines.sourceName)) - sourceLines.tastyLines.forall(line => positions.exists(_.containsLine(line))) - - private def matchTargetName(method: binary.Method, symbol: TermSymbol): Boolean = - method.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr) - - private def matchTargetName(field: binary.Field, symbol: TermSymbol): Boolean = - field.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr) - - private case class SourceParams( - declaredParamNames: Seq[UnsignedTermName], - declaredParamTypes: Seq[Type], - expandedParamTypes: Seq[Type], - returnType: Type - ): - def regularParamTypes: Seq[Type] = declaredParamTypes ++ expandedParamTypes - - private case class BinaryParams( - capturedParams: Seq[binary.Parameter], - declaredParams: Seq[binary.Parameter], - expandedParams: Seq[binary.Parameter], - returnType: Option[binary.Type] - ): - def regularParams = declaredParams ++ expandedParams - - private def matchLiftedFunSignature(method: binary.Method, tree: LiftedTree[TermSymbol]): Boolean = - val sourceParams = extractSourceParams(method, tree.tpe) - val binaryParams = splitBinaryParams(method, sourceParams) - - def matchParamNames: Boolean = - sourceParams.declaredParamNames - .corresponds(binaryParams.declaredParams)((name, javaParam) => name.toString == javaParam.name) - - def matchTypeErasure: Boolean = - sourceParams.regularParamTypes - .corresponds(binaryParams.regularParams)((tpe, javaParam) => matchArgType(tpe, javaParam.`type`, false)) && - matchReturnType(sourceParams.returnType, binaryParams.returnType) - - matchParamNames && matchTypeErasure && matchCapture(tree.capture, binaryParams.capturedParams) - end matchLiftedFunSignature - - private def matchCapture(capture: Seq[String], capturedParams: Seq[binary.Parameter]): Boolean = - val anonymousPattern = "\\$\\d+".r - val evidencePattern = "evidence\\$\\d+".r - def toPattern(variable: String): Regex = - variable match - case anonymousPattern() => "\\$\\d+\\$\\$\\d+".r - case evidencePattern() => "evidence\\$\\d+\\$\\d+".r - case _ => - val encoded = NameTransformer.encode(variable) - s"${Regex.quote(encoded)}(\\$$tailLocal\\d+)?(\\$$lzy\\d+)?\\$$\\d+".r - val patterns = capture.map(toPattern) - def isCapture(param: String) = - patterns.exists(_.unapplySeq(param).nonEmpty) - def isProxy(param: String) = "(.+)\\$proxy\\d+\\$\\d+".r.unapplySeq(param).nonEmpty - def isThisOrOuter(param: String) = "(.+_|\\$)(this|outer)\\$\\d+".r.unapplySeq(param).nonEmpty - def isLazy(param: String) = "(.+)\\$lzy\\d+\\$\\d+".r.unapplySeq(param).nonEmpty - capturedParams.forall(p => isProxy(p.name) || isCapture(p.name) || isThisOrOuter(p.name) || isLazy(p.name)) - - private def matchSignature( - method: binary.Method, - declaredType: TypeOrMethodic, - expandContextFunction: Boolean = true, - captureAllowed: Boolean = true, - asJavaVarargs: Boolean = false, - checkParamNames: Boolean = true, - checkTypeErasure: Boolean = true - ): Boolean = - val sourceParams = extractSourceParams(method, declaredType) - val binaryParams = splitBinaryParams(method, sourceParams) - - def matchParamNames: Boolean = - sourceParams.declaredParamNames - .corresponds(binaryParams.declaredParams)((name, javaParam) => name.toString == javaParam.name) - - def matchTypeErasure: Boolean = - sourceParams.regularParamTypes - .corresponds(binaryParams.regularParams)((tpe, javaParam) => - matchArgType(tpe, javaParam.`type`, asJavaVarargs) - ) && matchReturnType(sourceParams.returnType, binaryParams.returnType) - - (captureAllowed || binaryParams.capturedParams.isEmpty) && - binaryParams.capturedParams.forall(_.isGenerated) && - binaryParams.expandedParams.forall(_.isGenerated) && - sourceParams.regularParamTypes.size == binaryParams.regularParams.size && - (!checkParamNames || matchParamNames) && - (!checkTypeErasure || matchTypeErasure) - end matchSignature - - private def extractSourceParams(method: binary.Method, tpe: TermType): SourceParams = - val (expandedParamTypes, returnType) = - if method.isConstructor && method.declaringClass.isJavaLangEnum then - (List(defn.StringType, defn.IntType), tpe.returnType) - else if !method.isAnonFun then expandContextFunctions(tpe.returnType, acc = Nil) - else (List.empty, tpe.returnType) - SourceParams(tpe.allParamNames, tpe.allParamTypes, expandedParamTypes, returnType) - - /* After code generation, a method ends up with more than its declared parameters. - * - * It has, in order: - * - captured params, - * - declared params, - * - "expanded" params (from java.lang.Enum constructors and uncurried context function types). - * - * We can only check the names of declared params. - * We can check the (erased) type of declared and expanded params; together we call them "regular" params. - */ - private def splitBinaryParams(method: binary.Method, sourceParams: SourceParams): BinaryParams = - val (capturedParams, regularParams) = - method.allParameters.splitAt(method.allParameters.size - sourceParams.regularParamTypes.size) - val (declaredParams, expandedParams) = regularParams.splitAt(sourceParams.declaredParamTypes.size) - BinaryParams(capturedParams, declaredParams, expandedParams, method.returnType) - - private def expandContextFunctions(tpe: Type, acc: List[Type]): (List[Type], Type) = - tpe.safeDealias match - case Some(tpe: AppliedType) if tpe.tycon.isContextFunction => - val argsAsTypes = tpe.args.map(_.highIfWildcard) - expandContextFunctions(argsAsTypes.last, acc ::: argsAsTypes.init) - case _ => (acc, tpe) - - private lazy val scalaPrimitivesToJava: Map[ClassSymbol, String] = Map( - defn.BooleanClass -> "boolean", - defn.ByteClass -> "byte", - defn.CharClass -> "char", - defn.DoubleClass -> "double", - defn.FloatClass -> "float", - defn.IntClass -> "int", - defn.LongClass -> "long", - defn.ShortClass -> "short", - defn.UnitClass -> "void", - defn.NullClass -> "scala.runtime.Null$" - ) - - private def matchSetterArgType(scalaVarType: TypeOrMethodic, javaSetterParamType: binary.Type): Boolean = - scalaVarType match - case scalaVarType: Type => - scalaVarType.erasedAsArgType(asJavaVarargs = false).exists(matchType(_, javaSetterParamType)) - case _: MethodicType => false - - private def matchArgType(scalaType: Type, javaType: binary.Type, asJavaVarargs: Boolean): Boolean = - scalaType.erasedAsArgType(asJavaVarargs).exists(matchType(_, javaType)) - - private def matchReturnType(scalaType: TermType, javaType: Option[binary.Type]): Boolean = - scalaType match - case scalaType: Type => javaType.forall(jt => scalaType.erasedAsReturnType.exists(matchType(_, jt))) - case _: MethodicType | _: PackageRef => false - - private lazy val dollarDigitsMaybeDollarAtEndRegex = "\\$\\d+\\$?$".r - - private def matchType( - scalaType: ErasedTypeRef, - javaType: binary.Type - ): Boolean = - def rec(scalaType: ErasedTypeRef, javaType: String): Boolean = - scalaType match - case ErasedTypeRef.ArrayTypeRef(base, dimensions) => - javaType.endsWith("[]" * dimensions) && - rec(base, javaType.dropRight(2 * dimensions)) - case ErasedTypeRef.ClassRef(scalaClass) => - scalaPrimitivesToJava.get(scalaClass) match - case Some(javaPrimitive) => javaPrimitive == javaType - case None => matchClassType(scalaClass, javaType, nested = false) - rec(scalaType, javaType.name) - - private def matchClassType(scalaClass: ClassSymbol, javaType: String, nested: Boolean): Boolean = - def encodedName(nested: Boolean): String = scalaClass.name match - case ObjectClassTypeName(underlying) if nested => NameTransformer.encode(underlying.toString()) - case name => NameTransformer.encode(name.toString()) - scalaClass.owner match - case owner: PackageSymbol => - javaType == owner.fullName.toString() + "." + encodedName(nested) - case owner: ClassSymbol => - val encodedName1 = encodedName(nested) - javaType.endsWith("$" + encodedName1) && - matchClassType(owner, javaType.dropRight(1 + encodedName1.length()), nested = true) - case owner: TermOrTypeSymbol => - dollarDigitsMaybeDollarAtEndRegex.findFirstIn(javaType).exists { suffix => - val prefix = javaType.stripSuffix(suffix) - val encodedName1 = encodedName(nested = true) - prefix.endsWith("$" + encodedName1) && { - val ownerJavaType = prefix.dropRight(1 + encodedName1.length()) - enclosingClassOwners(owner).exists(matchClassType(_, ownerJavaType, nested = true)) - } - } - - private def enclosingClassOwners(sym: TermOrTypeSymbol): List[ClassSymbol] = - sym.owner match - case owner: ClassSymbol => owner :: enclosingClassOwners(owner) - case owner: TermOrTypeSymbol => enclosingClassOwners(owner) - case owner: PackageSymbol => Nil - - private def decodeCapturedLzyVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] = - decodedMethod match - case m: DecodedMethod.LazyInit if m.symbol.nameStr == name => - Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) - case m: DecodedMethod.ValOrDefDef if m.symbol.nameStr == name => - Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) - case _ => decodeCapturedVariable(decodedMethod, name) - - private def decodeCapturedVariable(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] = - for - metTree <- decodedMethod.treeOpt.toSeq - sym <- CaptureCollector.collectCaptures(metTree) - if name == sym.nameStr - yield DecodedVariable.CapturedVariable(decodedMethod, sym) - - private def decodeValDef( - decodedMethod: DecodedMethod, - variable: binary.Variable, - sourceLine: Int - ): Seq[DecodedVariable] = decodedMethod.symbolOpt match - case Some(owner: TermSymbol) if owner.sourceLanguage == SourceLanguage.Scala2 => - for - paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten - if variable.name == paramSym.nameStr - yield DecodedVariable.ValDef(decodedMethod, paramSym) - case _ => - for - tree <- decodedMethod.treeOpt.toSeq - localVar <- VariableCollector.collectVariables(tree).toSeq - if variable.name == localVar.sym.nameStr && - (decodedMethod.isGenerated || !variable.declaringMethod.sourceLines.exists(x => - localVar.sourceFile.name.endsWith(x.sourceName) - ) || (localVar.startLine <= sourceLine && sourceLine <= localVar.endLine)) - yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) - - private def decodeSAMOrPartialFun( - decodedMethod: DecodedMethod, - owner: TermSymbol, - variable: binary.Variable, - sourceLine: Int - ): Seq[DecodedVariable] = - if owner.sourceLanguage == SourceLanguage.Scala2 then - val x = - for - paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten - if variable.name == paramSym.nameStr - yield DecodedVariable.ValDef(decodedMethod, paramSym) - x.toSeq - else - for - localVar <- owner.tree.toSeq.flatMap(t => VariableCollector.collectVariables(t)) - if variable.name == localVar.sym.nameStr - yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) - - private def unexpandedSymName(sym: Symbol): String = - "(.+)\\$\\w+".r.unapplySeq(sym.nameStr).map(xs => xs(0)).getOrElse(sym.nameStr) - - private def decodeProxy( - decodedMethod: DecodedMethod, - name: String - ): Seq[DecodedVariable] = - for - metTree <- decodedMethod.treeOpt.toSeq - localVar <- VariableCollector.collectVariables(metTree) - if name == localVar.sym.nameStr - yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) - - private def decodeInlinedThis( - decodedMethod: DecodedMethod, - variable: binary.Variable - ): Seq[DecodedVariable] = - val decodedClassSym = variable.`type` match - case cls: binary.ClassType => decode(cls).classSymbol - case _ => None - for - metTree <- decodedMethod.treeOpt.toSeq - decodedClassSym <- decodedClassSym.toSeq - if VariableCollector.collectVariables(metTree, sym = decodedMethod.symbolOpt).exists { localVar => - localVar.sym == decodedClassSym - } - yield DecodedVariable.This(decodedMethod, decodedClassSym.thisType) - - private def decodeDollarThis( - decodedMethod: DecodedMethod - ): Seq[DecodedVariable] = - decodedMethod match - case _: DecodedMethod.TraitStaticForwarder => - decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) - case _ => - for - classSym <- decodedMethod.owner.companionClassSymbol.toSeq - if classSym.isSubClass(defn.AnyValClass) - sym <- classSym.declarations.collect { - case sym: TermSymbol if sym.isVal && !sym.isMethod => sym - } - yield DecodedVariable.AnyValThis(decodedMethod, sym) + new BinaryDecoder(using ctx): + // we cache successes and failures, which can be quite expensive too + private val classCache: TrieMap[String, Try[DecodedClass]] = TrieMap.empty + private val methodCache: TrieMap[(String, binary.SignedName), Try[DecodedMethod]] = TrieMap.empty + private val liftedTreesCache: TrieMap[Symbol, Try[Seq[LiftedTree[?]]]] = TrieMap.empty + + override def decode(cls: binary.BinaryClass): DecodedClass = + classCache.getOrElseUpdate(cls.name, Try(super.decode(cls))).get + + override def decode(method: binary.Method): DecodedMethod = + methodCache.getOrElseUpdate((method.declaringClass.name, method.signedName), Try(super.decode(method))).get + + override protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = + liftedTreesCache.getOrElseUpdate(owner, Try(super.collectAllLiftedTrees(owner))).get + end new + +class BinaryDecoder(using Context, ThrowOrWarn) + extends BinaryClassDecoder, + BinaryMethodDecoder, + BinaryFieldDecoder, + BinaryVariableDecoder: + def context = ctx diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryFieldDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryFieldDecoder.scala new file mode 100644 index 0000000..19f988c --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryFieldDecoder.scala @@ -0,0 +1,87 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.internal.* +import tastyquery.Contexts.* +import tastyquery.Symbols.* + +trait BinaryFieldDecoder(using Context, ThrowOrWarn): + self: BinaryClassDecoder => + + def decode(field: binary.Field): DecodedField = + val decodedClass = decode(field.declaringClass) + decode(decodedClass, field) + + def decode(decodedClass: DecodedClass, field: binary.Field): DecodedField = + def tryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] = + f.applyOrElse(field, _ => Seq.empty[DecodedField]) + + extension (xs: Seq[DecodedField]) + def orTryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] = + if xs.nonEmpty then xs else f.applyOrElse(field, _ => Seq.empty[DecodedField]) + val decodedFields = + tryDecode { + case Patterns.LazyVal(name) => + for + owner <- decodedClass.classSymbol.toSeq ++ decodedClass.linearization.filter(_.isTrait) + sym <- owner.declarations.collect { + case sym: TermSymbol if sym.nameStr == name && sym.isModuleOrLazyVal => sym + } + yield DecodedField.ValDef(decodedClass, sym) + case Patterns.Module() => + decodedClass.classSymbol.flatMap(_.moduleValue).map(DecodedField.ModuleVal(decodedClass, _)).toSeq + case Patterns.Offset(nbr) => + Seq(DecodedField.LazyValOffset(decodedClass, nbr, defn.LongType)) + case Patterns.Outer() => decodeOuter(decodedClass) + case Patterns.SerialVersionUID() => + Seq(DecodedField.SerialVersionUID(decodedClass, defn.LongType)) + case Patterns.LazyValBitmap(name) => + Seq(DecodedField.LazyValBitmap(decodedClass, defn.BooleanType, name)) + case Patterns.AnyValCapture() => + for + classSym <- decodedClass.symbolOpt.toSeq + outerClass <- classSym.outerClass.toSeq + if outerClass.isSubClass(defn.AnyValClass) + sym <- outerClass.declarations.collect { + case sym: TermSymbol if sym.isVal && !sym.isMethod => sym + } + yield DecodedField.Capture(decodedClass, sym) + case Patterns.CapturedLzyVariable(names) => decodeCapture(decodedClass, names) + case Patterns.Capture(names) => decodeCapture(decodedClass, names) + case _ if field.isStatic && decodedClass.isJava => + for + owner <- decodedClass.companionClassSymbol.toSeq + sym <- owner.declarations.collect { case sym: TermSymbol if sym.nameStr == field.name => sym } + yield DecodedField.ValDef(decodedClass, sym) + }.orTryDecode { case _ => + for + owner <- withCompanionIfExtendsJavaLangEnum(decodedClass) ++ decodedClass.linearization.filter(_.isTrait) + sym <- owner.declarations.collect { + case sym: TermSymbol if matchTargetName(field, sym) && !sym.isMethod => sym + } + yield DecodedField.ValDef(decodedClass, sym) + } + decodedFields.singleOrThrow(field) + end decode + + private def decodeOuter(decodedClass: DecodedClass): Seq[DecodedField] = + decodedClass.symbolOpt + .flatMap(_.outerClass) + .map(outerClass => DecodedField.Outer(decodedClass, outerClass.selfType)) + .toSeq + + private def decodeCapture(decodedClass: DecodedClass, names: Seq[String]): Seq[DecodedField] = + for + clsSym <- decodedClass.symbolOpt.toSeq + scope = scoper.getScope(clsSym) + capturedSym <- scope.capturedVariables + if names.contains(capturedSym.nameStr) + yield DecodedField.Capture(decodedClass, capturedSym) + + private def withCompanionIfExtendsJavaLangEnum(decodedClass: DecodedClass): Seq[ClassSymbol] = + decodedClass.classSymbol.toSeq.flatMap { cls => + if cls.isSubClass(Definitions.javaLangEnumClass) then Seq(cls) ++ cls.companionClass + else Seq(cls) + } + + private def matchTargetName(field: binary.Field, symbol: TermSymbol): Boolean = + field.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr) diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryMethodDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryMethodDecoder.scala new file mode 100644 index 0000000..be3f53a --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryMethodDecoder.scala @@ -0,0 +1,808 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.internal.* +import tastyquery.Contexts.* +import tastyquery.Names.* +import tastyquery.SourcePosition +import tastyquery.Symbols.* +import tastyquery.Trees.* +import tastyquery.Types.* + +import scala.util.matching.Regex + +trait BinaryMethodDecoder(using Context, ThrowOrWarn): + self: BinaryClassDecoder => + def decode(method: binary.Method): DecodedMethod = + val decodedClass = decode(method.declaringClass) + decode(decodedClass, method) + + def decode(decodedClass: DecodedClass, method: binary.Method): DecodedMethod = + def tryDecode(f: PartialFunction[binary.Method, Seq[DecodedMethod]]): Seq[DecodedMethod] = + f.applyOrElse(method, _ => Seq.empty[DecodedMethod]) + + extension (xs: Seq[DecodedMethod]) + def orTryDecode(f: PartialFunction[binary.Method, Seq[DecodedMethod]]): Seq[DecodedMethod] = + if xs.nonEmpty then xs else f.applyOrElse(method, _ => Seq.empty[DecodedMethod]) + val candidates = + tryDecode { + // static and/or bridge + case Patterns.AdaptedAnonFun() => decodeAdaptedAnonFun(decodedClass, method) + // bridge or standard + case Patterns.SpecializedMethod(names) => decodeSpecializedMethod(decodedClass, method, names) + // bridge only + case m if m.isBridge => decodeBridgesAndMixinForwarders(decodedClass, method).toSeq + // static or standard + case Patterns.AnonFun() => decodeAnonFunsAndReduceAmbiguity(decodedClass, method) + case Patterns.ByNameArgProxy() => decodeByNameArgsProxy(decodedClass, method) + case Patterns.SuperArg() => decodeSuperArgs(decodedClass, method) + case Patterns.LiftedTree() => decodeLiftedTries(decodedClass, method) + case Patterns.LocalLazyInit(names) => decodeLocalLazyInit(decodedClass, method, names) + // static only + case Patterns.TraitInitializer() => decodeTraitInitializer(decodedClass, method) + case Patterns.DeserializeLambda() => + Seq(DecodedMethod.DeserializeLambda(decodedClass, Definitions.DeserializeLambdaType)) + case Patterns.TraitStaticForwarder() => decodeTraitStaticForwarder(decodedClass, method).toSeq + case m if m.isStatic && decodedClass.isJava => decodeStaticJavaMethods(decodedClass, method) + // cannot be static anymore + case Patterns.LazyInit(name) => decodeLazyInit(decodedClass, name) + case Patterns.Outer() => decodeOuter(decodedClass).toSeq + case Patterns.ParamForwarder(names) => decodeParamForwarder(decodedClass, method, names) + case Patterns.TraitSetter(name) => decodeTraitSetter(decodedClass, method, name) + case Patterns.Setter(names) => + decodeStandardMethods(decodedClass, method).orIfEmpty(decodeSetter(decodedClass, method, names)) + case Patterns.SuperAccessor(names) => decodeSuperAccessor(decodedClass, method, names) + } + .orTryDecode { case Patterns.ValueClassExtension() => decodeValueClassExtension(decodedClass, method) } + .orTryDecode { case Patterns.InlineAccessor(names) => decodeInlineAccessor(decodedClass, method, names).toSeq } + .orTryDecode { case Patterns.LocalMethod(names) => decodeLocalMethods(decodedClass, method, names) } + .orTryDecode { + case m if m.isStatic => decodeStaticForwarder(decodedClass, method) + case _ => decodeStandardMethods(decodedClass, method) + } + + candidates.singleOrThrow(method) + end decode + + private def decodeStaticJavaMethods(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + decodedClass.companionClassSymbol.toSeq + .flatMap(_.declarations) + .collect { + case sym: TermSymbol + if matchTargetName(method, sym) && matchSignature(method, sym.declaredType, checkParamNames = false) => + DecodedMethod.ValOrDefDef(decodedClass, sym) + } + + private def decodeStandardMethods(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + def rec(underlying: DecodedClass): Seq[DecodedMethod] = + underlying match + case anonFun: DecodedClass.SAMOrPartialFunction => + if method.isConstructor then Seq(DecodedMethod.SAMOrPartialFunctionConstructor(decodedClass, anonFun.tpe)) + else if anonFun.parentClass == Definitions.PartialFunctionClass then + decodePartialFunctionImpl(decodedClass, anonFun.tpe, method).toSeq + else decodeSAMFunctionImpl(decodedClass, anonFun.symbol, anonFun.parentClass, method).toSeq + case underlying: DecodedClass.ClassDef => decodeInstanceMethods(decodedClass, underlying.symbol, method) + case _: DecodedClass.SyntheticCompanionClass => Seq.empty + case inlined: DecodedClass.InlinedClass => rec(inlined.underlying) + rec(decodedClass) + + private def decodeParamForwarder( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod.ValOrDefDef] = + decodedClass.declarations.collect { + case sym: TermSymbol if names.contains(sym.targetNameStr) && matchSignature(method, sym.declaredType) => + DecodedMethod.ValOrDefDef(decodedClass, sym) + } + + private def decodeTraitSetter( + decodedClass: DecodedClass, + method: binary.Method, + name: String + ): Seq[DecodedMethod.SetterAccessor] = + for + traitSym <- decodedClass.linearization.filter(_.isTrait) + if method.decodedName.contains("$" + traitSym.nameStr + "$") + sym <- traitSym.declarations.collect { + case sym: TermSymbol if sym.targetNameStr == name && !sym.isMethod && !sym.isAbstractMember => sym + } + paramType <- decodedClass.thisType.map(sym.typeAsSeenFrom).collect { case tpe: Type => tpe } + yield + val tpe = MethodType(List(SimpleName("x$1")), List(paramType), defn.UnitType) + DecodedMethod.SetterAccessor(decodedClass, sym, tpe) + + private def decodeSetter( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod.SetterAccessor] = + for + param <- method.parameters.lastOption.toSeq + sym <- decodeFields(decodedClass, param.`type`, names) + yield + val tpe = MethodType(List(SimpleName("x$1")), List(sym.declaredType.asInstanceOf[Type]), defn.UnitType) + DecodedMethod.SetterAccessor(decodedClass, sym, tpe) + + private def decodeFields( + decodedClass: DecodedClass, + binaryType: binary.Type, + names: Seq[String] + ): Seq[TermSymbol] = + def matchType0(sym: TermSymbol): Boolean = matchSetterArgType(sym.declaredType, binaryType) + decodedClass.declarations.collect { + case sym: TermSymbol if !sym.isMethod && names.exists(sym.targetNameStr == _) && matchType0(sym) => + sym + } + + private def decodeSuperAccessor( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod] = + for + traitSym <- decodedClass.linearization.filter(_.isTrait) + if method.decodedName.contains("$" + traitSym.nameStr + "$") + sym <- traitSym.declarations.collect { + case sym: TermSymbol if names.contains(sym.targetNameStr) && !sym.isAbstractMember => sym + } + expectedTpe <- decodedClass.thisType.map(sym.typeAsSeenFrom(_)) + if matchSignature(method, expectedTpe) + yield DecodedMethod.SuperAccessor(decodedClass, sym, expectedTpe) + + private def decodeSpecializedMethod( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod.SpecializedMethod] = + decodedClass.declarations.collect { + case sym: TermSymbol + if names.contains(sym.targetNameStr) && + matchSignature( + method, + sym.declaredType, + captureAllowed = false, + checkParamNames = false, + checkTypeErasure = false + ) && + // hack: in Scala 3 only overriding symbols can be specialized (Function and Tuple) + sym.allOverriddenSymbols.nonEmpty => + DecodedMethod.SpecializedMethod(decodedClass, sym) + } + + private def decodeInlineAccessor( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod] = + val classLoader = method.declaringClass.classLoader + val methodAccessors = method.instructions + .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => + classLoader.loadClass(owner).method(name, descriptor) + } + .singleOpt + .flatten + .map { binaryTarget => + val target = decode(binaryTarget) + // val tpe = target.declaredType.asSeenFrom(fromType, fromClass) + DecodedMethod.InlineAccessor(decodedClass, target) + } + def singleFieldInstruction(f: binary.Instruction.Field => Boolean) = method.instructions + .collect { case instr: binary.Instruction.Field => instr } + .singleOpt + .filter(f) + .toSeq + def fieldSetters = + val expectedNames = names.map(_.stripSuffix("_=")).distinct + for + instr <- singleFieldInstruction(f => f.isPut && f.unexpandedDecodedNames.exists(expectedNames.contains)) + binaryField <- classLoader.loadClass(instr.owner).declaredField(instr.name).toSeq + fieldOwner = decode(binaryField.declaringClass) + sym <- decodeFields(fieldOwner, binaryField.`type`, instr.unexpandedDecodedNames) + yield + val tpe = MethodType(List(SimpleName("x$1")), List(sym.declaredType.asInstanceOf[Type]), defn.UnitType) + val decodedTarget = DecodedMethod.SetterAccessor(fieldOwner, sym, tpe) + DecodedMethod.InlineAccessor(decodedClass, decodedTarget) + def fieldGetters = + for + instr <- singleFieldInstruction(f => !f.isPut && f.unexpandedDecodedNames.exists(names.contains)) + binaryField <- classLoader.loadClass(instr.owner).declaredField(instr.name).toSeq + fieldOwner = decode(binaryField.declaringClass) + sym <- decodeFields(fieldOwner, binaryField.`type`, instr.unexpandedDecodedNames) + yield DecodedMethod.InlineAccessor(decodedClass, DecodedMethod.ValOrDefDef(fieldOwner, sym)) + def moduleAccessors = + for + instr <- singleFieldInstruction(_.name == "MODULE$") + targetClass = decode(classLoader.loadClass(instr.owner)) + targetClassSym <- targetClass.classSymbol + targetTermSym <- targetClassSym.moduleValue + yield DecodedMethod.InlineAccessor(decodedClass, DecodedMethod.ValOrDefDef(targetClass, targetTermSym)) + def valueClassAccessors = + if method.instructions.isEmpty && method.isExtensionMethod then + for + companionClass <- decodedClass.companionClass.toSeq + param <- method.parameters.lastOption.toSeq + sym <- decodeFields(companionClass, param.`type`, names.map(_.stripSuffix("$extension"))) + yield + val decodedTarget = DecodedMethod.ValOrDefDef(decodedClass, sym) + DecodedMethod.InlineAccessor(decodedClass, decodedTarget) + else Seq.empty + methodAccessors.toSeq + .orIfEmpty(fieldSetters) + .orIfEmpty(fieldGetters) + .orIfEmpty(moduleAccessors.toSeq) + .orIfEmpty(valueClassAccessors) + + private def decodeInstanceMethods( + decodedClass: DecodedClass, + classSymbol: ClassSymbol, + method: binary.Method + ): Seq[DecodedMethod] = + if method.isConstructor && classSymbol.isSubClass(defn.AnyValClass) then + classSymbol.getAllOverloadedDecls(SimpleName("")).map(DecodedMethod.ValOrDefDef(decodedClass, _)) + else + val isJava = decodedClass.isJava + val fromClass = classSymbol.declarations + .collect { case sym: TermSymbol if matchTargetName(method, sym) => sym } + .collect { + case sym + if matchSignature( + method, + sym.declaredType, + asJavaVarargs = isJava, + captureAllowed = !isJava, + checkParamNames = !isJava + ) => + DecodedMethod.ValOrDefDef(decodedClass, sym) + case sym if !isJava && matchSignature(method, sym.declaredType, asJavaVarargs = true) => + DecodedMethod.Bridge(DecodedMethod.ValOrDefDef(decodedClass, sym), sym.declaredType) + } + fromClass.orIfEmpty(decodeAccessorsFromTraits(decodedClass, classSymbol, method)) + + private def decodeAccessorsFromTraits( + decodedClass: DecodedClass, + classSymbol: ClassSymbol, + method: binary.Method + ): Seq[DecodedMethod] = + if classSymbol.isTrait then Seq.empty + else decodeAccessorsFromTraits(decodedClass, classSymbol, classSymbol.thisType, method) + + private def decodeAccessorsFromTraits( + decodedClass: DecodedClass, + fromClass: ClassSymbol, + fromType: Type, + method: binary.Method + ): Seq[DecodedMethod] = + for + traitSym <- fromClass.linearization.filter(_.isTrait) + if !method.isExpanded || method.decodedName.contains("$" + traitSym.nameStr + "$") + sym <- traitSym.declarations + .collect { + case sym: TermSymbol if matchTargetName(method, sym) && matchSignature(method, sym.declaredType) => sym + } + if method.isExpanded == sym.isPrivate + if sym.isParamAccessor || sym.isSetter || !sym.isMethod + if sym.isOverridingSymbol(fromClass) + yield + val tpe = sym.typeAsSeenFrom(fromType) + if sym.isParamAccessor then DecodedMethod.TraitParamAccessor(decodedClass, sym) + else if sym.isSetter then DecodedMethod.SetterAccessor(decodedClass, sym, tpe) + else DecodedMethod.GetterAccessor(decodedClass, sym, tpe) + + private def decodeLazyInit(decodedClass: DecodedClass, name: String): Seq[DecodedMethod] = + val matcher: PartialFunction[Symbol, TermSymbol] = + case sym: TermSymbol if sym.isModuleOrLazyVal && sym.nameStr == name => sym + val fromClass = decodedClass.declarations.collect(matcher).map(DecodedMethod.LazyInit(decodedClass, _)) + def fromTraits = + for + traitSym <- decodedClass.linearization.filter(_.isTrait) + term <- traitSym.declarations.collect(matcher) + if term.isOverridingSymbol(decodedClass) + yield DecodedMethod.LazyInit(decodedClass, term) + fromClass.orIfEmpty(fromTraits) + + private def decodeTraitStaticForwarder( + decodedClass: DecodedClass, + method: binary.Method + ): Option[DecodedMethod.TraitStaticForwarder] = + method.instructions + .collect { + case binary.Instruction.Method(_, owner, name, descriptor, _) if owner == method.declaringClass.name => + method.declaringClass.declaredMethod(name, descriptor) + } + .singleOpt + .flatten + .map(target => DecodedMethod.TraitStaticForwarder(decode(decodedClass, target))) + + private def decodeOuter(decodedClass: DecodedClass): Option[DecodedMethod.OuterAccessor] = + decodedClass.symbolOpt + .flatMap(_.outerClass) + .map(outerClass => DecodedMethod.OuterAccessor(decodedClass, outerClass.thisType)) + + private def decodeTraitInitializer( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod.ValOrDefDef] = + decodedClass.declarations.collect { + case sym: TermSymbol if sym.name == nme.Constructor => DecodedMethod.ValOrDefDef(decodedClass, sym) + } + + private def decodeValueClassExtension( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod.ValOrDefDef] = + val names = method.unexpandedDecodedNames.map(_.stripSuffix("$extension")) + decodedClass.companionClassSymbol.toSeq.flatMap(_.declarations).collect { + case sym: TermSymbol if names.contains(sym.targetNameStr) && matchSignature(method, sym.declaredType) => + DecodedMethod.ValOrDefDef(decodedClass, sym) + } + + private def decodeStaticForwarder( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod.StaticForwarder] = + decodedClass.companionClassSymbol.toSeq.flatMap(decodeStaticForwarder(decodedClass, _, method)) + + private def decodeStaticForwarder( + decodedClass: DecodedClass, + companionObject: ClassSymbol, + method: binary.Method + ): Seq[DecodedMethod.StaticForwarder] = + method.instructions + .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => + method.declaringClass.classLoader.loadClass(owner).method(name, descriptor) + } + .flatten + .singleOpt + .toSeq + .map(decode) + .collect { + case mixin: DecodedMethod.MixinForwarder => mixin.target + case target => target + } + .map { target => + val declaredType = target.symbolOpt + .map(_.typeAsSeenFrom(companionObject.thisType)) + .getOrElse(target.declaredType) + DecodedMethod.StaticForwarder(decodedClass, target, declaredType) + } + + private def decodeSAMFunctionImpl( + decodedClass: DecodedClass, + symbol: TermSymbol, + parentClass: ClassSymbol, + method: binary.Method + ): Option[DecodedMethod] = + val types = + for + parentCls <- parentClass.linearization.iterator + overridden <- parentCls.declarations.collect { case term: TermSymbol if matchTargetName(method, term) => term } + if overridden.overridingSymbol(parentClass).exists(_.isAbstractMember) + yield DecodedMethod.SAMOrPartialFunctionImpl(decodedClass, overridden, symbol.declaredType) + types.nextOption + + private def decodePartialFunctionImpl( + decodedClass: DecodedClass, + tpe: Type, + method: binary.Method + ): Option[DecodedMethod] = + for sym <- Definitions.PartialFunctionClass.getNonOverloadedDecl(SimpleName(method.name)) yield + val implTpe = sym.typeAsSeenFrom(SkolemType(tpe)) + DecodedMethod.SAMOrPartialFunctionImpl(decodedClass, sym, implTpe) + + private def decodeBridgesAndMixinForwarders( + decodedClass: DecodedClass, + method: binary.Method + ): Option[DecodedMethod] = + def rec(underlying: DecodedClass): Option[DecodedMethod] = + underlying match + case underlying: DecodedClass.ClassDef => + if !underlying.symbol.isTrait then + decodeBridgesAndMixinForwarders(decodedClass, underlying.symbol, underlying.symbol.thisType, method) + else None + case underlying: DecodedClass.SAMOrPartialFunction => + decodeBridgesAndMixinForwarders(decodedClass, underlying.parentClass, SkolemType(underlying.tpe), method) + case underlying: DecodedClass.InlinedClass => rec(underlying.underlying) + case _: DecodedClass.SyntheticCompanionClass => None + rec(decodedClass) + + private def decodeBridgesAndMixinForwarders( + decodedClass: DecodedClass, + fromClass: ClassSymbol, + fromType: Type, + method: binary.Method + ): Option[DecodedMethod] = + decodeBridges(decodedClass, fromClass, fromType, method) + .orIfEmpty(decodeMixinForwarder(decodedClass, method)) + + private def decodeBridges( + decodedClass: DecodedClass, + fromClass: ClassSymbol, + fromType: Type, + method: binary.Method + ): Option[DecodedMethod] = + method.instructions + .collect { + case binary.Instruction.Method(_, owner, name, descriptor, _) if name == method.name => + method.declaringClass.classLoader.loadClass(owner).method(name, descriptor) + } + .singleOpt + .flatten + .map { binaryTarget => + val target = decode(binaryTarget) + val tpe = target.declaredType.asSeenFrom(fromType, fromClass) + DecodedMethod.Bridge(target, tpe) + } + + private def decodeMixinForwarder( + decodedClass: DecodedClass, + method: binary.Method + ): Option[DecodedMethod.MixinForwarder] = + method.instructions + .collect { case binary.Instruction.Method(_, owner, name, descriptor, _) => + method.declaringClass.classLoader.loadClass(owner).declaredMethod(name, descriptor) + } + .singleOpt + .flatten + .filter(target => target.isStatic && target.declaringClass.isInterface) + .map(decode) + .collect { case staticForwarder: DecodedMethod.TraitStaticForwarder => + DecodedMethod.MixinForwarder(decodedClass, staticForwarder.target) + } + + private def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match + case classDef: DecodedClass.ClassDef => + Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass)) + case _: DecodedClass.SyntheticCompanionClass => Seq.empty + case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol) + case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying) + + private def decodeAdaptedAnonFun(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + if method.instructions.nonEmpty then + val underlying = method.instructions + .collect { + case binary.Instruction.Method(_, owner, name, descriptor, _) if owner == method.declaringClass.name => + method.declaringClass.declaredMethod(name, descriptor) + } + .flatten + .singleOrElse(unexpected(s"$method is not an adapted method: cannot find underlying invocation")) + decodeAnonFunsAndByNameArgs(decodedClass, underlying).map(DecodedMethod.AdaptedFun(_)) + else Seq.empty + + private def decodeAnonFunsAndReduceAmbiguity( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod] = + val candidates = decodeAnonFunsAndByNameArgs(decodedClass, method) + if candidates.size > 1 then + val clashingMethods = method.declaringClass.declaredMethods + .filter(m => m.returnType.zip(method.returnType).forall(_ == _) && m.name != method.name) + .collect { case m @ Patterns.AnonFun() if m.name != method.name => m } + .map(m => m -> decodeAnonFunsAndByNameArgs(decodedClass, m).toSet) + .toMap + def reduceAmbiguity( + methods: Map[binary.Method, Set[DecodedMethod]] + ): Map[binary.Method, Set[DecodedMethod]] = + val found = methods.collect { case (m, syms) if syms.size == 1 => syms.head } + val reduced = methods.map { case (m, candidates) => + if candidates.size > 1 then m -> (candidates -- found) + else m -> candidates + } + if reduced.count { case (_, s) => s.size == 1 } == found.size then methods + else reduceAmbiguity(reduced) + reduceAmbiguity(clashingMethods + (method -> candidates.toSet))(method).toSeq + else candidates + + private def decodeAnonFunsAndByNameArgs( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod] = + val anonFuns = decodeLocalMethods(decodedClass, method, Seq(CommonNames.anonFun.toString)) + val byNameArgs = + if method.parameters.forall(_.isCapturedParam) then decodeByNameArgs(decodedClass, method) + else Seq.empty + reduceAmbiguity(anonFuns ++ byNameArgs) + + private def decodeLocalMethods( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod] = + collectLocalMethods(decodedClass, method) { + case fun if names.contains(fun.symbol.name.toString) && matchLiftedFunSignature(method, fun) => + wrapIfInline(fun, DecodedMethod.ValOrDefDef(decodedClass, fun.symbol.asTerm)) + } + + private def decodeByNameArgs(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + collectLiftedTrees(decodedClass, method) { case arg: ByNameArg if !arg.isInline => arg } + .collect { + case arg if matchReturnType(arg.tpe, method.returnType) && matchCapture(arg, method.parameters) => + wrapIfInline(arg, DecodedMethod.ByNameArg(decodedClass, arg.owner, arg.tree, arg.tpe.asInstanceOf)) + } + + private def decodeByNameArgsProxy(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + val explicitByNameArgs = + collectLiftedTrees(decodedClass, method) { case arg: ByNameArg if arg.isInline => arg } + .collect { + case arg if matchReturnType(arg.tpe, method.returnType) && matchCapture(arg, method.parameters) => + wrapIfInline(arg, DecodedMethod.ByNameArg(decodedClass, arg.owner, arg.tree, arg.tpe.asInstanceOf)) + } + val inlineOverrides = + for + classSym <- decodedClass.classSymbol.toSeq + sym <- classSym.declarations.collect { + case sym: TermSymbol if sym.allOverriddenSymbols.nonEmpty && sym.isInline => sym + } + if method.sourceLines.forall(sym.pos.matchLines) + paramSym <- sym.paramSymbols + resultType <- Seq(paramSym.declaredType).collect { case tpe: ByNameType => tpe.resultType } + if matchReturnType(resultType, method.returnType) + yield + val argTree = Ident(paramSym.name)(paramSym.localRef)(SourcePosition.NoPosition) + DecodedMethod.ByNameArg(decodedClass, sym, argTree, resultType) + explicitByNameArgs ++ inlineOverrides + + private def decodeSuperArgs( + decodedClass: DecodedClass, + method: binary.Method + ): Seq[DecodedMethod.SuperConstructorArg] = + def matchSuperArg(liftedArg: LiftedTree[Nothing]): Boolean = + val primaryConstructor = liftedArg.owner.asClass.getAllOverloadedDecls(nme.Constructor).head + // a super arg takes the same parameters as its constructor + val sourceParams = extractSourceParams(method, primaryConstructor.declaredType) + val binaryParams = splitBinaryParams(method, sourceParams) + matchReturnType(liftedArg.tpe, method.returnType) && matchCapture(liftedArg, binaryParams.capturedParams) + collectLiftedTrees(decodedClass, method) { case arg: ConstructorArg => arg } + .collect { + case liftedArg if matchSuperArg(liftedArg) => + DecodedMethod.SuperConstructorArg( + decodedClass, + liftedArg.owner.asClass, + liftedArg.tree, + liftedArg.tpe.asInstanceOf + ) + } + + private def decodeLiftedTries(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] = + collectLiftedTrees(decodedClass, method) { case tree: LiftedTry => tree } + .collect { + case liftedTry if matchReturnType(liftedTry.tpe, method.returnType) => + wrapIfInline( + liftedTry, + DecodedMethod.LiftedTry(decodedClass, liftedTry.owner, liftedTry.tree, liftedTry.tpe.asInstanceOf) + ) + } + + private def decodeLocalLazyInit( + decodedClass: DecodedClass, + method: binary.Method, + names: Seq[String] + ): Seq[DecodedMethod] = + collectLocalMethods(decodedClass, method) { + case term if term.symbol.isModuleOrLazyVal && names.contains(term.symbol.nameStr) => + wrapIfInline(term, DecodedMethod.LazyInit(decodedClass, term.symbol)) + } + + private def matchTargetName(method: binary.Method, symbol: TermSymbol): Boolean = + method.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr) + + private def matchSignature( + method: binary.Method, + declaredType: TypeOrMethodic, + captureAllowed: Boolean = true, + asJavaVarargs: Boolean = false, + checkParamNames: Boolean = true, + checkTypeErasure: Boolean = true + ): Boolean = + val sourceParams = extractSourceParams(method, declaredType) + val binaryParams = splitBinaryParams(method, sourceParams) + + def matchParamNames: Boolean = + sourceParams.declaredParamNames + .corresponds(binaryParams.declaredParams)((name, javaParam) => name.toString == javaParam.name) + + def matchTypeErasure: Boolean = + sourceParams.regularParamTypes + .corresponds(binaryParams.regularParams)((tpe, javaParam) => + matchArgType(tpe, javaParam.`type`, asJavaVarargs) + ) && matchReturnType(sourceParams.returnType, binaryParams.returnType) + + (captureAllowed || binaryParams.capturedParams.isEmpty) && + binaryParams.capturedParams.forall(_.isGeneratedParam) && + binaryParams.expandedParams.forall(_.isGeneratedParam) && + sourceParams.regularParamTypes.size == binaryParams.regularParams.size && + (!checkParamNames || matchParamNames) && + (!checkTypeErasure || matchTypeErasure) + end matchSignature + + private def matchSetterArgType(scalaVarType: TypeOrMethodic, javaSetterParamType: binary.Type): Boolean = + scalaVarType match + case scalaVarType: Type => + scalaVarType.erasedAsArgType(asJavaVarargs = false).exists(matchType(_, javaSetterParamType)) + case _: MethodicType => false + + private def collectLocalMethods( + decodedClass: DecodedClass, + method: binary.Method + )( + matcher: PartialFunction[LiftedTree[TermSymbol], DecodedMethod] + ): Seq[DecodedMethod] = + collectLiftedTrees(decodedClass, method) { case term: LocalTermDef => term } + .collect(matcher) + + private def collectLiftedTrees[S](decodedClass: DecodedClass, method: binary.Method)( + matcher: PartialFunction[LiftedTree[?], LiftedTree[S]] + ): Seq[LiftedTree[S]] = + val owners = withCompanionIfExtendsAnyVal(decodedClass) + val sourceLines = + if owners.size == 2 && method.parameters.exists(p => p.name.matches("\\$this\\$\\d+")) then + // workaround of https://github.com/lampepfl/dotty/issues/18816 + method.sourceLines.map(_.last) + else method.sourceLines + owners.flatMap(collectLiftedTrees(_, sourceLines)(matcher)) + + private def reduceAmbiguity(syms: Seq[DecodedMethod]): Seq[DecodedMethod] = + if syms.size > 1 then + val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _))) + if reduced.size != 0 then reduced else syms + else syms + + private def wrapIfInline(liftedTree: LiftedTree[?], decodedMethod: DecodedMethod): DecodedMethod = + liftedTree match + case InlinedFromDef(underlying, inlineCall) => + DecodedMethod.InlinedMethod(wrapIfInline(underlying, decodedMethod), inlineCall.callTree) + case InlinedFromArg(underlying, params, inlineCall) => + DecodedMethod.InlinedMethodFromArg(wrapIfInline(underlying, decodedMethod), params, inlineCall) + case _ => decodedMethod + + private def matchLiftedFunSignature(method: binary.Method, tree: LiftedTree[TermSymbol]): Boolean = + val sourceParams = extractSourceParams(method, tree.tpe) + val binaryParams = splitBinaryParams(method, sourceParams) + + def matchParamNames: Boolean = + sourceParams.declaredParamNames + .corresponds(binaryParams.declaredParams)((name, javaParam) => name.toString == javaParam.name) + + def matchTypeErasure: Boolean = + sourceParams.regularParamTypes + .corresponds(binaryParams.regularParams)((tpe, javaParam) => matchArgType(tpe, javaParam.`type`, false)) && + matchReturnType(sourceParams.returnType, binaryParams.returnType) + + matchParamNames && matchTypeErasure && matchCapture(tree, binaryParams.capturedParams) + end matchLiftedFunSignature + + private def matchReturnType(scalaType: TermType, javaType: Option[binary.Type]): Boolean = + scalaType match + case scalaType: Type => javaType.forall(jt => scalaType.erasedAsReturnType.exists(matchType(_, jt))) + case _: MethodicType | _: PackageRef => false + + private def extractSourceParams(method: binary.Method, tpe: TermType): SourceParams = + val (expandedParamTypes, returnType) = + if method.isConstructor && method.declaringClass.isJavaLangEnum then + (List(defn.StringType, defn.IntType), tpe.returnType) + else if !method.isAnonFun then tpe.returnType.expandContextFunctions + else (List.empty, tpe.returnType) + SourceParams(tpe.allParamNames, tpe.allParamTypes, expandedParamTypes, returnType) + + /* After code generation, a method ends up with more than its declared parameters. + * + * It has, in order: + * - captured params, + * - declared params, + * - "expanded" params (from java.lang.Enum constructors and uncurried context function types). + * + * We can only check the names of declared params. + * We can check the (erased) type of declared and expanded params; together we call them "regular" params. + */ + private def splitBinaryParams(method: binary.Method, sourceParams: SourceParams): BinaryParams = + val (capturedParams, regularParams) = + method.parameters.splitAt(method.parameters.size - sourceParams.regularParamTypes.size) + val (declaredParams, expandedParams) = regularParams.splitAt(sourceParams.declaredParamTypes.size) + BinaryParams(capturedParams, declaredParams, expandedParams, method.returnType) + + private def enclose(enclosing: DecodedMethod, enclosed: DecodedMethod): Boolean = + (enclosing, enclosed) match + case (enclosing: DecodedMethod.InlinedMethod, enclosed: DecodedMethod.InlinedMethod) => + enclosing.callPos.enclose(enclosed.callPos) || ( + !enclosed.callPos.enclose(enclosing.callPos) && + enclose(enclosing.underlying, enclosed.underlying) + ) + case (enclosing: DecodedMethod.InlinedMethod, enclosed) => + enclosing.callPos.enclose(enclosed.pos) + case (enclosing, enclosed: DecodedMethod.InlinedMethod) => + enclosing.pos.enclose(enclosed.callPos) + case (enclosing, enclosed) => + enclosing.pos.enclose(enclosed.pos) + + private case class SourceParams( + declaredParamNames: Seq[UnsignedTermName], + declaredParamTypes: Seq[Type], + expandedParamTypes: Seq[Type], + returnType: Type + ): + def regularParamTypes: Seq[Type] = declaredParamTypes ++ expandedParamTypes + + private case class BinaryParams( + capturedParams: Seq[binary.Parameter], + declaredParams: Seq[binary.Parameter], + expandedParams: Seq[binary.Parameter], + returnType: Option[binary.Type] + ): + def regularParams = declaredParams ++ expandedParams + + private def matchCapture(liftedTree: LiftedTree[?], capturedParams: Seq[binary.Parameter]): Boolean = + val anonymousPattern = "\\$\\d+".r + val evidencePattern = "evidence\\$\\d+".r + def toPattern(variable: TermSymbol): Regex = + variable.nameStr match + case anonymousPattern() => "\\$\\d+\\$\\$\\d+".r + case evidencePattern() => "evidence\\$\\d+\\$\\d+".r + case _ => + val encoded = NameTransformer.encode(variable.nameStr) + s"${Regex.quote(encoded)}(\\$$tailLocal\\d+)?(\\$$lzy\\d+)?\\$$\\d+".r + val patterns = liftedTree.scope(scoper).capturedVariables.map(toPattern) + def isCapture(param: String) = + patterns.exists(_.unapplySeq(param).nonEmpty) + def isProxy(param: String) = "(.+)\\$proxy\\d+\\$\\d+".r.unapplySeq(param).nonEmpty + def isThisOrOuter(param: String) = "(.+_|\\$)(this|outer)\\$\\d+".r.unapplySeq(param).nonEmpty + def isLazy(param: String) = "(.+)\\$lzy\\d+\\$\\d+".r.unapplySeq(param).nonEmpty + capturedParams.forall(p => isProxy(p.name) || isCapture(p.name) || isThisOrOuter(p.name) || isLazy(p.name)) + + protected def matchArgType(scalaType: Type, javaType: binary.Type, asJavaVarargs: Boolean): Boolean = + scalaType.erasedAsArgType(asJavaVarargs).exists(matchType(_, javaType)) + + private lazy val scalaPrimitivesToJava: Map[ClassSymbol, String] = Map( + defn.BooleanClass -> "boolean", + defn.ByteClass -> "byte", + defn.CharClass -> "char", + defn.DoubleClass -> "double", + defn.FloatClass -> "float", + defn.IntClass -> "int", + defn.LongClass -> "long", + defn.ShortClass -> "short", + defn.UnitClass -> "void", + defn.NullClass -> "scala.runtime.Null$" + ) + + private def matchType( + scalaType: ErasedTypeRef, + javaType: binary.Type + ): Boolean = + def rec(scalaType: ErasedTypeRef, javaType: String): Boolean = + scalaType match + case ErasedTypeRef.ArrayTypeRef(base, dimensions) => + javaType.endsWith("[]" * dimensions) && + rec(base, javaType.dropRight(2 * dimensions)) + case ErasedTypeRef.ClassRef(scalaClass) => + scalaPrimitivesToJava.get(scalaClass) match + case Some(javaPrimitive) => javaPrimitive == javaType + case None => matchClassType(scalaClass, javaType, nested = false) + rec(scalaType, javaType.name) + + private lazy val dollarDigitsMaybeDollarAtEndRegex = "\\$\\d+\\$?$".r + + private def matchClassType(scalaClass: ClassSymbol, javaType: String, nested: Boolean): Boolean = + def encodedName(nested: Boolean): String = scalaClass.name match + case ObjectClassTypeName(underlying) if nested => NameTransformer.encode(underlying.toString()) + case name => NameTransformer.encode(name.toString()) + scalaClass.owner match + case owner: PackageSymbol => + javaType == owner.fullName.toString() + "." + encodedName(nested) + case owner: ClassSymbol => + val encodedName1 = encodedName(nested) + javaType.endsWith("$" + encodedName1) && + matchClassType(owner, javaType.dropRight(1 + encodedName1.length()), nested = true) + case owner: TermOrTypeSymbol => + dollarDigitsMaybeDollarAtEndRegex.findFirstIn(javaType).exists { suffix => + val prefix = javaType.stripSuffix(suffix) + val encodedName1 = encodedName(nested = true) + prefix.endsWith("$" + encodedName1) && { + val ownerJavaType = prefix.dropRight(1 + encodedName1.length()) + enclosingClassOwners(owner).exists(matchClassType(_, ownerJavaType, nested = true)) + } + } + + private def enclosingClassOwners(sym: TermOrTypeSymbol): List[ClassSymbol] = + sym.owner match + case owner: ClassSymbol => owner :: enclosingClassOwners(owner) + case owner: TermOrTypeSymbol => enclosingClassOwners(owner) + case owner: PackageSymbol => Nil diff --git a/src/main/scala/ch/epfl/scala/decoder/BinaryVariableDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/BinaryVariableDecoder.scala new file mode 100644 index 0000000..3165182 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/BinaryVariableDecoder.scala @@ -0,0 +1,250 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.internal.* +import tastyquery.Contexts.* +import tastyquery.SourceLanguage +import tastyquery.Symbols.* +import tastyquery.Types.* + +trait BinaryVariableDecoder(using Context, ThrowOrWarn): + self: BinaryClassDecoder & BinaryMethodDecoder => + + def decode(variable: binary.Variable, sourceLine: Int): DecodedVariable = + val decodedMethod = decode(variable.declaringMethod) + decode(decodedMethod, variable, sourceLine) + + def decode(decodedMethod: DecodedMethod, variable: binary.Variable, sourceLine: Int): DecodedVariable = + def tryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = + f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) + + extension (xs: Seq[DecodedVariable]) + def orTryDecode(f: PartialFunction[binary.Variable, Seq[DecodedVariable]]): Seq[DecodedVariable] = + if xs.nonEmpty then xs else f.applyOrElse(variable, _ => Seq.empty[DecodedVariable]) + val decodedVariables = tryDecode { + case Patterns.CapturedLzyVariable(name) => + if variable.declaringMethod.isConstructor then decodeClassCapture(decodedMethod, variable, name) + else decodeCapturedLzyVariable(decodedMethod, variable, name) + case Patterns.CapturedTailLocalVariable(name) => decodeMethodCapture(decodedMethod, variable, name) + case Patterns.AnyValCapture() => decodeAnyValCapture(decodedMethod) + case Patterns.CapturedProxy(nameWithProxy, name) => + decodeMethodCapture(decodedMethod, variable, name) + .orIfEmpty(decodeFromInlinedLambda(decodedMethod, variable, name)) + case Patterns.Capture(name) => + if variable.declaringMethod.isConstructor then decodeClassCapture(decodedMethod, variable, name) + else decodeMethodCapture(decodedMethod, variable, name) + case Patterns.This() => decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) + case Patterns.DollarThis() => decodeDollarThis(decodedMethod) + case Patterns.Proxy(name) => decodeProxy(decodedMethod, name) + case Patterns.InlinedThis() => decodeInlinedThis(decodedMethod, variable) + case Patterns.Outer() => decodeOuterParam(decodedMethod) + }.orTryDecode { case _ => + decodedMethod match + case decodedMethod: DecodedMethod.SAMOrPartialFunctionImpl => + decodeValDef(decodedMethod, variable, sourceLine) + .orIfEmpty(decodeSAMOrPartialFun(decodedMethod, decodedMethod.implementedSymbol, variable, sourceLine)) + case _: DecodedMethod.Bridge => ignore(variable, "Bridge method") + case _ => + if variable.isParameter then + decodeParameter(decodedMethod, variable) + .orIfEmpty( + decodeValDef(decodedMethod, variable, sourceLine) + ) // if the method returns a contextual function + else if variable.declaringMethod.isConstructor then + decodeValDef(decodedMethod, variable, sourceLine) + .orIfEmpty(decodeLocalValDefInConstructor(decodedMethod, variable, sourceLine)) + else decodeValDef(decodedMethod, variable, sourceLine) + } + decodedVariables.singleOrThrow(variable, decodedMethod) + + private def decodeCapturedLzyVariable( + decodedMethod: DecodedMethod, + variable: binary.Variable, + name: String + ): Seq[DecodedVariable] = + decodedMethod.base match + case m: DecodedMethod.LazyInit if m.symbol.nameStr == name => + Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) + case m: DecodedMethod.ValOrDefDef if m.symbol.nameStr == name => + Seq(DecodedVariable.CapturedVariable(decodedMethod, m.symbol)) + case _ => decodeMethodCapture(decodedMethod, variable, name).filter(_.symbol.isModuleOrLazyVal) + + private def decodeMethodCapture( + decodedMethod: DecodedMethod, + variable: binary.Variable, + name: String + ): Seq[DecodedVariable.CapturedVariable] = + for + scope <- getScope(decodedMethod).toSeq + sym <- scope.capturedVariables + if name == sym.nameStr && matchCaptureType(sym, variable.`type`) + yield DecodedVariable.CapturedVariable(decodedMethod, sym) + + private def decodeFromInlinedLambda( + decodedMethod: DecodedMethod, + variable: binary.Variable, + name: String + ): Seq[DecodedVariable.CapturedVariable] = + decodedMethod match + case m: DecodedMethod.InlinedMethodFromArg => + for + sym <- m.inlineCall.symbol.paramSymbols + if sym.nameStr == name + yield DecodedVariable.CapturedVariable(decodedMethod, sym) + case _ => Seq.empty + + private def getScope(decodedSym: DecodedSymbol): Option[Scope] = + decodedSym.symbolOpt + .map(scoper.getScope) + .orElse(decodedSym.treeOpt.map(scoper.getScope)) + .map: baseScope => + decodedSym match + case m: DecodedMethod.InlinedMethodFromArg => + scoper.inlinedFromLambdaArg(baseScope, m.lambdaParams, m.inlineCall) + case _ => baseScope + + private def decodeParameter(decodedMethod: DecodedMethod, variable: binary.Variable): Seq[DecodedVariable] = + (decodedMethod, variable) match + case (m: DecodedMethod.AdaptedFun, Patterns.V(i)) => + for + owner <- decodedMethod.symbolOpt.toSeq.collect { case sym: TermSymbol => sym } + if owner.paramSymbols.size > i - 1 + yield DecodedVariable.ValDef(m, owner.paramSymbols(i - 1)) + case (m: DecodedMethod.SetterAccessor, Patterns.XDollar(0)) if !m.symbol.isMethod => + Seq(DecodedVariable.SetterParam(m, m.symbol.declaredType.asInstanceOf[Type])) + case (m: DecodedMethod.SpecializedMethod, Patterns.XDollar(i)) => + if m.symbol.paramSymbols.size > i then Seq(DecodedVariable.SpecializedParam(m, m.symbol.paramSymbols(i))) + else Seq.empty + case _ => + for + owner <- decodedMethod.symbolOpt.toSeq.collect { case sym: TermSymbol => sym } + params <- owner.paramSymss.collect { case Left(value) => value } + sym <- params + if variable.name == sym.nameStr || matchEmptyName(variable, sym) + yield DecodedVariable.ValDef(decodedMethod, sym) + + private def matchEmptyName(variable: binary.Variable, sym: TermSymbol): Boolean = + val EmptyName = "(arg|x\\$)(\\d+)".r + (variable.name, sym.nameStr) match + case (EmptyName(_, x), EmptyName(_, y)) => x == y + case _ => false + + private def decodeValDef( + decodedMethod: DecodedMethod, + variable: binary.Variable, + sourceLine: Int + ): Seq[DecodedVariable] = + for + tree <- decodedMethod.treeOpt.toSeq + localVar <- VariableCollector.collectVariables(scoper, tree).toSeq + if variable.name == localVar.sym.nameStr && matchSourceLinesAndType(variable, localVar, sourceLine) + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeLocalValDefInConstructor( + decodedMethod: DecodedMethod, + variable: binary.Variable, + sourceLine: Int + ): Seq[DecodedVariable] = + for + tree <- decodedMethod.owner.treeOpt.toSeq + localVar <- VariableCollector.collectVariables(scoper, tree).toSeq + if variable.name == localVar.sym.nameStr && matchSourceLinesAndType(variable, localVar, sourceLine) + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeSAMOrPartialFun( + decodedMethod: DecodedMethod, + owner: TermSymbol, + variable: binary.Variable, + sourceLine: Int + ): Seq[DecodedVariable] = + if owner.sourceLanguage == SourceLanguage.Scala2 then + val x = + for + paramSym <- owner.paramSymss.collect { case Left(value) => value }.flatten + if variable.name == paramSym.nameStr + yield DecodedVariable.ValDef(decodedMethod, paramSym) + x.toSeq + else + for + localVar <- owner.tree.toSeq.flatMap(t => VariableCollector.collectVariables(scoper, t)) + if variable.name == localVar.sym.nameStr + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeProxy(decodedMethod: DecodedMethod, name: String): Seq[DecodedVariable] = + for + metTree <- decodedMethod.treeOpt.toSeq + localVar <- VariableCollector.collectVariables(scoper, metTree) + if name == localVar.sym.nameStr + yield DecodedVariable.ValDef(decodedMethod, localVar.sym.asTerm) + + private def decodeInlinedThis(decodedMethod: DecodedMethod, variable: binary.Variable): Seq[DecodedVariable] = + val decodedClassSym = variable.`type` match + case cls: binary.BinaryClass => decode(cls).classSymbol + case _ => None + for + metTree <- decodedMethod.treeOpt.toSeq + decodedClassSym <- decodedClassSym.toSeq + localVariables = VariableCollector.collectVariables(scoper, metTree, sym = decodedMethod.symbolOpt) + if localVariables.map(_.sym).filter(_.isClass).exists(decodedClassSym.linearization.contains) + yield DecodedVariable.This(decodedMethod, decodedClassSym.thisType) + + private def decodeOuterParam(decodedMethod: DecodedMethod): Seq[DecodedVariable] = + decodedMethod.owner.symbolOpt + .flatMap(_.outerClass) + .map(outerClass => DecodedVariable.OuterParam(decodedMethod, outerClass.selfType)) + .toSeq + + private def decodeClassCapture( + decodedMethod: DecodedMethod, + variable: binary.Variable, + name: String + ): Seq[DecodedVariable] = + getScope(decodedMethod.owner).toSeq + .flatMap(_.capturedVariables) + .filter(sym => name == sym.nameStr && matchCaptureType(sym, variable.`type`)) + .map(DecodedVariable.CapturedVariable(decodedMethod, _)) + + private def decodeDollarThis(decodedMethod: DecodedMethod): Seq[DecodedVariable] = + decodedMethod match + case _: DecodedMethod.TraitStaticForwarder => + decodedMethod.owner.thisType.toSeq.map(DecodedVariable.This(decodedMethod, _)) + case _ => + for + classSym <- decodedMethod.owner.companionClassSymbol.toSeq + if classSym.isSubClass(defn.AnyValClass) + sym <- classSym.declarations.collect { + case sym: TermSymbol if sym.isVal && !sym.isMethod => sym + } + yield DecodedVariable.AnyValThis(decodedMethod, sym) + + private def decodeAnyValCapture(decodedMethod: DecodedMethod): Seq[DecodedVariable] = + for + classSym <- decodedMethod.owner.companionClassSymbol.toSeq + if classSym.isSubClass(defn.AnyValClass) + sym <- classSym.declarations.collect { + case sym: TermSymbol if sym.isVal && !sym.isMethod => sym + } + yield DecodedVariable.CapturedVariable(decodedMethod, sym) + + private def matchCaptureType(sym: TermSymbol, binaryTpe: binary.Type): Boolean = + if sym.isModuleOrLazyVal then binaryTpe.isLazy + else if sym.isVar then binaryTpe.isRef + else matchArgType(sym.declaredType.requireType, binaryTpe, false) + + private def matchSourceLinesAndType(variable: binary.Variable, localVar: LocalVariable, sourceLine: Int): Boolean = + matchSourceLines(variable, localVar, sourceLine) && matchType(variable.`type`, localVar) + + private def matchSourceLines(variable: binary.Variable, localVar: LocalVariable, sourceLine: Int): Boolean = + val sourceName = variable.sourceName.getOrElse("") + // we use endsWith instead of == because of tasty-query#434 + val positions = localVar.positions.filter(pos => pos.sourceFile.name.endsWith(sourceName)) + variable.declaringMethod.isConstructor || + positions.isEmpty || // can happen when localVar has been inlined from a transparent inline + positions.exists(_.containsLine(sourceLine - 1)) + + private def matchType(binaryTpe: binary.Type, localVar: LocalVariable): Boolean = + localVar.sym match + case sym: ClassSymbol => matchArgType(localVar.tpe, binaryTpe, false) + case sym: TermSymbol => + if sym.isModuleOrLazyVal then binaryTpe.isLazy + else (sym.isVar && binaryTpe.isRef) || matchArgType(localVar.tpe, binaryTpe, false) + case _ => false diff --git a/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala b/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala index 2218188..647f6b9 100644 --- a/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala +++ b/src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala @@ -3,8 +3,9 @@ package ch.epfl.scala.decoder import ch.epfl.scala.decoder.internal.showBasic import tastyquery.SourcePosition import tastyquery.Symbols.* -import tastyquery.Trees.Tree +import tastyquery.Trees.* import tastyquery.Types.* +import ch.epfl.scala.decoder.internal.InlineCall sealed trait DecodedSymbol: def symbolOpt: Option[ClassSymbol | TermSymbol] = None @@ -35,6 +36,7 @@ object DecodedClass: else s"$underlying (inlined)" sealed trait DecodedMethod extends DecodedSymbol: + def base: DecodedMethod = this def owner: DecodedClass override def symbolOpt: Option[TermSymbol] = None def declaredType: TypeOrMethodic @@ -154,14 +156,30 @@ object DecodedMethod: override def toString: String = s"AdaptedFun($owner, ${declaredType.showBasic})" final class InlinedMethod(val underlying: DecodedMethod, val callTree: Tree) extends DecodedMethod: + override def base: DecodedMethod = underlying.base override def owner: DecodedClass = underlying.owner override def declaredType: TypeOrMethodic = underlying.declaredType override def symbolOpt: Option[TermSymbol] = underlying.symbolOpt + override def treeOpt: Option[Tree] = underlying.treeOpt def callPos: SourcePosition = callTree.pos override def toString: String = if underlying.isInstanceOf[InlinedMethod] then underlying.toString else s"$underlying (inlined)" + final class InlinedMethodFromArg( + val underlying: DecodedMethod, + val lambdaParams: Seq[TermSymbol], + val inlineCall: InlineCall + ) extends DecodedMethod: + override def base: DecodedMethod = underlying.base + override def owner: DecodedClass = underlying.owner + override def declaredType: TypeOrMethodic = underlying.declaredType + override def symbolOpt: Option[TermSymbol] = underlying.symbolOpt + override def treeOpt: Option[Tree] = underlying.treeOpt + override def toString: String = + if underlying.isInstanceOf[InlinedMethodFromArg] then underlying.toString + else s"$underlying (inlined from arg)" + sealed trait DecodedField extends DecodedSymbol: def owner: DecodedClass override def symbolOpt: Option[TermSymbol] = None @@ -203,7 +221,7 @@ object DecodedVariable: final class ValDef(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: def declaredType: TypeOrMethodic = symbol.declaredType override def symbolOpt: Option[TermSymbol] = Some(symbol) - override def toString: String = s"LocalVariable($owner, ${symbol.showBasic})" + override def toString: String = s"ValDef($owner, ${symbol.showBasic})" final class CapturedVariable(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: def declaredType: TypeOrMethodic = symbol.declaredType @@ -213,7 +231,19 @@ object DecodedVariable: final class This(val owner: DecodedMethod, val declaredType: Type) extends DecodedVariable: override def toString: String = s"This($owner, ${declaredType.showBasic})" + final class OuterParam(val owner: DecodedMethod, val declaredType: Type) extends DecodedVariable: + override def toString: String = s"OuterParam($owner, ${declaredType.showBasic})" + final class AnyValThis(val owner: DecodedMethod, val symbol: TermSymbol) extends DecodedVariable: def declaredType: TypeOrMethodic = symbol.declaredType override def symbolOpt: Option[TermSymbol] = Some(symbol) override def toString: String = s"AnyValThis($owner, ${declaredType.showBasic})" + + final class SetterParam(val owner: DecodedMethod.SetterAccessor, val declaredType: Type) extends DecodedVariable: + override def toString: String = s"SetterParam($owner, ${declaredType.showBasic})" + + final class SpecializedParam(val owner: DecodedMethod.SpecializedMethod, val symbol: TermSymbol) + extends DecodedVariable: + def declaredType: TypeOrMethodic = symbol.declaredType + override def symbolOpt: Option[TermSymbol] = Some(symbol) + override def toString: String = s"SpecializedParam($owner, ${symbol.showBasic})" diff --git a/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala b/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala index 02aba7b..281c7b0 100644 --- a/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala +++ b/src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala @@ -14,8 +14,6 @@ class StackTraceFormatter(using ThrowOrWarn): val typeAscription = variable.declaredType match case tpe: Type => ": " + format(tpe) case tpe => format(tpe) - val test1 = formatOwner(variable) - val test2 = formatName(variable) formatName(variable) + typeAscription def formatMethodSignatures(input: String): String = { @@ -86,6 +84,7 @@ class StackTraceFormatter(using ThrowOrWarn): case method: DecodedMethod.AdaptedFun => formatOwner(method.target) case method: DecodedMethod.SAMOrPartialFunctionConstructor => format(method.owner) case method: DecodedMethod.InlinedMethod => formatOwner(method.underlying) + case method: DecodedMethod.InlinedMethodFromArg => formatOwner(method.underlying) private def formatOwner(field: DecodedField): String = format(field.owner) @@ -109,6 +108,9 @@ class StackTraceFormatter(using ThrowOrWarn): case variable: DecodedVariable.CapturedVariable => formatName(variable.symbol).dot("") case variable: DecodedVariable.This => "this" case variable: DecodedVariable.AnyValThis => formatName(variable.symbol) + case variable: DecodedVariable.OuterParam => "" + case variable: DecodedVariable.SetterParam => "x$0" + case variable: DecodedVariable.SpecializedParam => formatName(variable.symbol).dot("") private def formatName(method: DecodedMethod): String = method match @@ -137,6 +139,7 @@ class StackTraceFormatter(using ThrowOrWarn): case method: DecodedMethod.AdaptedFun => formatName(method.target).dot("") case _: DecodedMethod.SAMOrPartialFunctionConstructor => "" case method: DecodedMethod.InlinedMethod => formatName(method.underlying) + case method: DecodedMethod.InlinedMethodFromArg => formatName(method.underlying) private def formatOwner(sym: Symbol): String = formatAsOwner(sym.owner) @@ -179,6 +182,7 @@ class StackTraceFormatter(using ThrowOrWarn): private def format(name: Name): String = def rec(name: Name): String = name match case DefaultGetterName(termName, num) => s"${termName.toString()}." + case SimpleTypeName("") => "Object" case name: TypeName => rec(name.toTermName) case SimpleName("$anonfun") => "" case SimpleName("$anon") => "" diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClass.scala b/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClass.scala new file mode 100644 index 0000000..37e29c5 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClass.scala @@ -0,0 +1,18 @@ +package ch.epfl.scala.decoder.binary + +trait BinaryClass extends Type: + def name: String + def isInterface: Boolean + def superclass: Option[BinaryClass] + def interfaces: Seq[BinaryClass] + def method(name: String, descriptor: String): Option[Method] + def declaredField(name: String): Option[Field] + def declaredMethod(name: String, descriptor: String): Option[Method] + def declaredMethods: Seq[Method] + def declaredFields: Seq[Field] + def classLoader: BinaryClassLoader + + private[decoder] def isObject = name.endsWith("$") + private[decoder] def isPackageObject = name.endsWith(".package$") || name.endsWith("$package$") + private[decoder] def isPartialFunction = superclass.exists(_.name == "scala.runtime.AbstractPartialFunction") + private[decoder] def isJavaLangEnum = superclass.exists(_.name == "java.lang.Enum") diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClassLoader.scala b/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClassLoader.scala index 06bd519..78538d6 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClassLoader.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/BinaryClassLoader.scala @@ -1,4 +1,4 @@ package ch.epfl.scala.decoder.binary trait BinaryClassLoader: - def loadClass(name: String): ClassType + def loadClass(name: String): BinaryClass diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/ClassType.scala b/src/main/scala/ch/epfl/scala/decoder/binary/ClassType.scala deleted file mode 100644 index 92ab646..0000000 --- a/src/main/scala/ch/epfl/scala/decoder/binary/ClassType.scala +++ /dev/null @@ -1,18 +0,0 @@ -package ch.epfl.scala.decoder.binary - -trait ClassType extends Type: - def name: String - def isInterface: Boolean - def superclass: Option[ClassType] - def interfaces: Seq[ClassType] - def method(name: String, descriptor: String): Option[Method] - def declaredField(name: String): Option[Field] - def declaredMethod(name: String, descriptor: String): Option[Method] - def declaredMethods: Seq[Method] - def declaredFields: Seq[Field] - def classLoader: BinaryClassLoader - - def isObject = name.endsWith("$") - def isPackageObject = name.endsWith(".package$") || name.endsWith("$package$") - def isPartialFunction = superclass.exists(_.name == "scala.runtime.AbstractPartialFunction") - def isJavaLangEnum = superclass.exists(_.name == "java.lang.Enum") diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Field.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Field.scala index 53d1b36..9731dad 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Field.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Field.scala @@ -1,6 +1,6 @@ package ch.epfl.scala.decoder.binary trait Field extends Symbol: - def declaringClass: ClassType + def declaringClass: BinaryClass def `type`: Type def isStatic: Boolean diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala index 196583c..e7cdd4c 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Method.scala @@ -1,27 +1,28 @@ package ch.epfl.scala.decoder.binary trait Method extends Symbol: - def declaringClass: ClassType - def allParameters: Seq[Parameter] - def variables: Seq[Variable] + def declaringClass: BinaryClass + def signedName: SignedName + def parameters: Seq[Parameter] // return None if the class of the return type is not yet loaded def returnType: Option[Type] - def returnTypeName: String - def isBridge: Boolean + def variables: Seq[Variable] + def isConstructor: Boolean def isStatic: Boolean def isFinal: Boolean + def isBridge: Boolean def instructions: Seq[Instruction] - def isConstructor: Boolean - def signedName: SignedName - def isExtensionMethod: Boolean = name.endsWith("$extension") && !isStatic && !isBridge - def isTraitStaticForwarder: Boolean = + def name: String = signedName.name + + private[decoder] def isExtensionMethod: Boolean = name.endsWith("$extension") && !isStatic && !isBridge + private[decoder] def isTraitStaticForwarder: Boolean = declaringClass.isInterface && isStatic && name.endsWith("$") && !isDeserializeLambda && !isTraitInitializer - def isTraitInitializer: Boolean = name == "$init$" && isStatic - def isClassInitializer: Boolean = name == "" - def isPartialFunctionApplyOrElse: Boolean = declaringClass.isPartialFunction && name == "applyOrElse" - def isDeserializeLambda: Boolean = + private[decoder] def isTraitInitializer: Boolean = name == "$init$" && isStatic + private[decoder] def isClassInitializer: Boolean = name == "" + private[decoder] def isPartialFunctionApplyOrElse: Boolean = declaringClass.isPartialFunction && name == "applyOrElse" + private[decoder] def isDeserializeLambda: Boolean = isStatic && name == "$deserializeLambda$" && - allParameters.map(_.`type`.name) == Seq("java.lang.invoke.SerializedLambda") - def isAnonFun: Boolean = name.matches("(.*)\\$anonfun\\$\\d+") + parameters.map(_.`type`.name) == Seq("java.lang.invoke.SerializedLambda") + private[decoder] def isAnonFun: Boolean = name.matches("(.*)\\$anonfun\\$\\d+") diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Parameter.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Parameter.scala index 38be1ac..6e81aa9 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Parameter.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Parameter.scala @@ -1,11 +1,12 @@ package ch.epfl.scala.decoder.binary -trait Parameter extends Symbol: +trait Parameter extends Variable: def `type`: Type - def isThis: Boolean = name == "$this" - def isOuter: Boolean = name == "$outer" - def isCapture: Boolean = !name.matches("_\\$\\d+") && name.matches(".+\\$\\d+") - def isUnknownJavaArg: Boolean = name.matches("arg\\d+") - def isJavaLangEnumParam: Boolean = name == "_$name" || name == "_$ordinal" - def isGenerated: Boolean = isCapture || isOuter || isThis || isUnknownJavaArg || isJavaLangEnumParam + private[decoder] def isThisParam: Boolean = name == "$this" + private[decoder] def isOuterParam: Boolean = name == "$outer" + private[decoder] def isCapturedParam: Boolean = !name.matches("_\\$\\d+") && name.matches(".+\\$\\d+") + private[decoder] def isUnknownJavaParam: Boolean = name.matches("arg\\d+") + private[decoder] def isJavaLangEnumParam: Boolean = name == "_$name" || name == "_$ordinal" + private[decoder] def isGeneratedParam: Boolean = + isCapturedParam || isOuterParam || isThisParam || isUnknownJavaParam || isJavaLangEnumParam diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/SourceLines.scala b/src/main/scala/ch/epfl/scala/decoder/binary/SourceLines.scala index 802c801..c589e01 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/SourceLines.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/SourceLines.scala @@ -5,19 +5,14 @@ final case class SourceLines(sourceName: String, lines: Seq[Int]): if lines.size > 2 then Seq(lines.head, lines.last) else lines - def tastyLines = lines.map(_ - 1) - - def tastySpan: Seq[Int] = - span.map(_ - 1) - def showSpan: String = span.mkString("(", ", ", ")") - def filterTasty(f: Int => Boolean): SourceLines = copy(lines = lines.filter(l => f(l - 1))) - def last: SourceLines = copy(lines = lines.lastOption.toSeq) - def isEmpty: Boolean = lines.isEmpty + private[decoder] def tastyLines = lines.map(_ - 1) + private[decoder] def tastySpan: Seq[Int] = span.map(_ - 1) + object SourceLines: def apply(sourceName: String, lines: Seq[Int]): SourceLines = new SourceLines(sourceName, lines.distinct.sorted) diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala index 83b8366..f5b03b9 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Symbol.scala @@ -5,5 +5,4 @@ trait Symbol: def sourceLines: Option[SourceLines] def sourceName: Option[String] = sourceLines.map(_.sourceName) - def showSpan: String = - sourceLines.map(_.showSpan).getOrElse("") + def showSpan: String = sourceLines.map(_.showSpan).getOrElse("") diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Type.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Type.scala index bb5766e..9eb0a28 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Type.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Type.scala @@ -1,3 +1,7 @@ package ch.epfl.scala.decoder.binary -trait Type extends Symbol +trait Type extends Symbol: + private[decoder] def isLazy: Boolean = + "scala\\.runtime\\.Lazy(Boolean|Byte|Char|Short|Int|Long|Float|Double|Unit|Ref)".r.matches(name) + private[decoder] def isRef: Boolean = + "scala\\.runtime\\.(Volatile)?(Boolean|Byte|Char|Short|Int|Long|Float|Double|Object)Ref".r.matches(name) diff --git a/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala b/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala index 28310c0..cc75c00 100644 --- a/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala +++ b/src/main/scala/ch/epfl/scala/decoder/binary/Variable.scala @@ -3,3 +3,4 @@ package ch.epfl.scala.decoder.binary trait Variable extends Symbol: def `type`: Type def declaringMethod: Method + def isParameter: Boolean = isInstanceOf[Parameter] diff --git a/src/main/scala/ch/epfl/scala/decoder/exceptions.scala b/src/main/scala/ch/epfl/scala/decoder/exceptions.scala index 0653c65..2eb7e74 100644 --- a/src/main/scala/ch/epfl/scala/decoder/exceptions.scala +++ b/src/main/scala/ch/epfl/scala/decoder/exceptions.scala @@ -13,12 +13,12 @@ case class IgnoredException(symbol: binary.Symbol, reason: String) case class UnexpectedException(message: String) extends Exception(message) -def notFound(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol] = None): Nothing = - throw NotFoundException(symbol, decodedOwner) +inline def notFound(symbol: binary.Symbol, decodedOwner: Option[DecodedSymbol] = None): Nothing = + throw new NotFoundException(symbol, decodedOwner) -def ambiguous(symbol: binary.Symbol, candidates: Seq[DecodedSymbol]): Nothing = - throw AmbiguousException(symbol, candidates) +inline def ambiguous(symbol: binary.Symbol, candidates: Seq[DecodedSymbol]): Nothing = + throw new AmbiguousException(symbol, candidates) -def ignore(symbol: binary.Symbol, reason: String): Nothing = throw IgnoredException(symbol, reason) +inline def ignore(symbol: binary.Symbol, reason: String): Nothing = throw new IgnoredException(symbol, reason) -def unexpected(message: String): Nothing = throw UnexpectedException(message) +inline def unexpected(message: String): Nothing = throw new UnexpectedException(message) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/CachedBinaryDecoder.scala b/src/main/scala/ch/epfl/scala/decoder/internal/CachedBinaryDecoder.scala deleted file mode 100644 index c681d80..0000000 --- a/src/main/scala/ch/epfl/scala/decoder/internal/CachedBinaryDecoder.scala +++ /dev/null @@ -1,22 +0,0 @@ -package ch.epfl.scala.decoder.internal - -import ch.epfl.scala.decoder.binary -import tastyquery.Contexts.* -import tastyquery.Symbols.* -import ch.epfl.scala.decoder.* - -import scala.collection.concurrent.TrieMap - -class CachedBinaryDecoder(using Context, ThrowOrWarn) extends BinaryDecoder: - private val classCache: TrieMap[String, DecodedClass] = TrieMap.empty - private val methodCache: TrieMap[(String, binary.SignedName), DecodedMethod] = TrieMap.empty - private val liftedTreesCache: TrieMap[Symbol, Seq[LiftedTree[?]]] = TrieMap.empty - - override def decode(cls: binary.ClassType): DecodedClass = - classCache.getOrElseUpdate(cls.name, super.decode(cls)) - - override def decode(method: binary.Method): DecodedMethod = - methodCache.getOrElseUpdate((method.declaringClass.name, method.signedName), super.decode(method)) - - override protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = - liftedTreesCache.getOrElseUpdate(owner, super.collectAllLiftedTrees(owner)) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala b/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala deleted file mode 100644 index c2f6c6e..0000000 --- a/src/main/scala/ch/epfl/scala/decoder/internal/CaptureCollector.scala +++ /dev/null @@ -1,43 +0,0 @@ -package ch.epfl.scala.decoder.internal - -import tastyquery.Trees.* -import scala.collection.mutable -import tastyquery.Symbols.* -import tastyquery.Traversers.* -import tastyquery.Contexts.* -import tastyquery.SourcePosition -import tastyquery.Types.* -import tastyquery.Traversers -import ch.epfl.scala.decoder.ThrowOrWarn -import scala.languageFeature.postfixOps - -object CaptureCollector: - def collectCaptures(tree: Tree)(using Context, ThrowOrWarn): Set[TermSymbol] = - val collector = CaptureCollector() - collector.traverse(tree) - collector.capture.toSet - -class CaptureCollector(using Context, ThrowOrWarn) extends TreeTraverser: - private val capture: mutable.Set[TermSymbol] = mutable.Set.empty - private val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty - - def loopCollect(symbol: Symbol)(collect: => Unit): Unit = - if !alreadySeen.contains(symbol) then - alreadySeen += symbol - collect - override def traverse(tree: Tree): Unit = - tree match - case _: TypeTree => () - case valDef: ValDef => - alreadySeen += valDef.symbol - traverse(valDef.rhs) - case bind: Bind => - alreadySeen += bind.symbol - traverse(bind.body) - case ident: Ident => - for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do - if !alreadySeen.contains(sym) && sym.isLocal then - if !sym.isMethod then capture += sym - if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse)) - else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse)) - case _ => super.traverse(tree) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/Definitions.scala b/src/main/scala/ch/epfl/scala/decoder/internal/Definitions.scala index 21abf79..37a8ad3 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/Definitions.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/Definitions.scala @@ -1,23 +1,26 @@ package ch.epfl.scala.decoder.internal -import tastyquery.Contexts.Context +import tastyquery.Contexts.* import tastyquery.Names.* import tastyquery.Types.* -class Definitions(using ctx: Context): - export ctx.defn.* +object Definitions: + def scalaRuntimePackage(using Context) = defn.scalaPackage.getPackageDecl(SimpleName("runtime")).get + def scalaAnnotationPackage(using Context) = defn.scalaPackage.getPackageDecl(SimpleName("annotation")).get + def javaPackage(using Context) = defn.RootPackage.getPackageDecl(SimpleName("java")).get + def javaIoPackage(using Context) = javaPackage.getPackageDecl(SimpleName("io")).get + def javaLangInvokePackage(using Context) = defn.javaLangPackage.getPackageDecl(SimpleName("invoke")).get - val scalaRuntimePackage = scalaPackage.getPackageDecl(SimpleName("runtime")).get - val javaPackage = RootPackage.getPackageDecl(SimpleName("java")).get - val javaIoPackage = javaPackage.getPackageDecl(SimpleName("io")).get - val javaLangInvokePackage = javaLangPackage.getPackageDecl(SimpleName("invoke")).get + def PartialFunctionClass(using Context) = defn.scalaPackage.getDecl(typeName("PartialFunction")).get.asClass + def AbstractPartialFunctionClass(using Context) = + scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass + def threadUnsafeClass(using Context) = scalaAnnotationPackage.getDecl(typeName("threadUnsafe")).get.asClass + def SerializableClass(using Context) = javaIoPackage.getDecl(typeName("Serializable")).get.asClass + def javaLangEnumClass(using Context) = defn.javaLangPackage.getDecl(typeName("Enum")).get.asClass - val PartialFunctionClass = scalaPackage.getDecl(typeName("PartialFunction")).get.asClass - val AbstractPartialFunctionClass = scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass - val SerializableClass = javaIoPackage.getDecl(typeName("Serializable")).get.asClass - val javaLangEnumClass = javaLangPackage.getDecl(typeName("Enum")).get.asClass + def SerializedLambdaType(using Context): Type = + TypeRef(javaLangInvokePackage.packageRef, typeName("SerializedLambda")) + def DeserializeLambdaType(using Context) = + MethodType(List(SimpleName("arg0")), List(SerializedLambdaType), defn.ObjectType) - val SerializedLambdaType: Type = TypeRef(javaLangInvokePackage.packageRef, typeName("SerializedLambda")) - val DeserializeLambdaType = MethodType(List(SimpleName("arg0")), List(SerializedLambdaType), ObjectType) - - val Function0Type = TypeRef(scalaPackage.packageRef, typeName("Function0")) + def Function0Type(using Context) = TypeRef(defn.scalaPackage.packageRef, typeName("Function0")) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/InlineCall.scala b/src/main/scala/ch/epfl/scala/decoder/internal/InlineCall.scala index 49636a6..207cb4b 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/InlineCall.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/InlineCall.scala @@ -4,6 +4,7 @@ import tastyquery.Trees.* import tastyquery.Symbols.* import tastyquery.Types.* import tastyquery.Contexts.* +import tastyquery.SourcePosition import tastyquery.decoder.Substituters import ch.epfl.scala.decoder.ThrowOrWarn @@ -15,9 +16,14 @@ case class InlineCall private ( ): def symbol(using Context): TermSymbol = termRefTree.symbol.asTerm + def pos: SourcePosition = callTree.pos + def substTypeParams(tpe: TermType)(using Context): TermType = Substituters.substLocalTypeParams(tpe, symbol.typeParamSymbols, typeArgs) + def substTypeParams(tpe: Type)(using Context): Type = + Substituters.substLocalTypeParams(tpe, symbol.typeParamSymbols, typeArgs) + def paramsMap(using Context): Map[TermSymbol, TermTree] = symbol.paramSymbols.zip(args).toMap @@ -28,7 +34,7 @@ object InlineCall: def unapply(fullTree: Tree)(using Context, ThrowOrWarn): Option[InlineCall] = def rec(tree: Tree, typeArgsAcc: List[Type], argsAcc: Seq[TermTree]): Option[InlineCall] = tree match - case termTree: TermReferenceTree if termTree.safeSymbol.exists(sym => sym.isInline && sym.asTerm.isMethod) => + case termTree: TermReferenceTree if termTree.safeTermSymbol.exists(sym => sym.isInline && sym.isMethod) => Some(InlineCall(termTree, typeArgsAcc, argsAcc, fullTree)) case Apply(fun, args) => rec(fun, typeArgsAcc, args ++ argsAcc) case TypeApply(fun, typeArgs) => rec(fun, typeArgs.map(_.toType) ++ typeArgsAcc, argsAcc) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTree.scala b/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTree.scala index cd772f8..71f73bd 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTree.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTree.scala @@ -1,13 +1,14 @@ package ch.epfl.scala.decoder.internal -import tastyquery.Trees.* +import ch.epfl.scala.decoder.* +import tastyquery.Contexts.* import tastyquery.SourcePosition import tastyquery.Symbols.* -import tastyquery.Contexts.* -import tastyquery.Types.* import tastyquery.Traversers.* +import tastyquery.Trees.* +import tastyquery.Types.* + import scala.collection.mutable -import ch.epfl.scala.decoder.* sealed trait LiftedTree[S]: def tree: Tree @@ -15,22 +16,17 @@ sealed trait LiftedTree[S]: def tpe: TermType def owner: Symbol - def inlinedFrom: List[InlineCall] = Nil - def inlinedArgs: Map[Symbol, Seq[TermTree]] = Map.empty - def positions(using Context, ThrowOrWarn): Seq[SourcePosition] = LiftedTree.collectPositions(this) - def capture(using Context, ThrowOrWarn): Seq[String] = LiftedTree.collectCapture(this) + def scope(scoper: Scoper): Scope = scoper.getScope(tree) end LiftedTree -sealed trait LocalTermDef extends LiftedTree[TermSymbol]: - def symbol: TermSymbol +sealed trait LocalTermDef(val symbol: TermSymbol) extends LiftedTree[TermSymbol]: def tpe: TypeOrMethodic = symbol.declaredType + override def scope(scoper: Scoper): Scope = scoper.getScope(symbol) -final case class LocalDef(tree: DefDef) extends LocalTermDef: - def symbol: TermSymbol = tree.symbol +final case class LocalDef(tree: DefDef) extends LocalTermDef(tree.symbol): def owner: Symbol = tree.symbol.owner -final case class LocalLazyVal(tree: ValDef) extends LocalTermDef: - def symbol: TermSymbol = tree.symbol +final case class LocalLazyVal(tree: ValDef) extends LocalTermDef(tree.symbol): def owner: Symbol = tree.symbol.owner final case class LambdaTree(lambda: Lambda)(using Context) extends LiftedTree[(TermSymbol, ClassSymbol)]: @@ -54,12 +50,11 @@ final case class ByNameArg(owner: Symbol, tree: TermTree, paramTpe: TermType, is def symbol: Nothing = unexpected("no symbol for by name arg") final case class ConstructorArg(owner: ClassSymbol, tree: TermTree, paramTpe: TermType)(using - ctx: Context, - defn: Definitions + ctx: Context ) extends LiftedTree[Nothing]: def tpe: TermType = paramTpe match - case _: ByNameType => defn.Function0Type.appliedTo(tree.tpe.asInstanceOf[Type]) + case _: ByNameType => Definitions.Function0Type.appliedTo(tree.tpe.asInstanceOf[Type]) case _ => tree.tpe def symbol: Nothing = unexpected("no symbol for constructor arg") @@ -70,109 +65,26 @@ final case class InlinedFromDef[S](underlying: LiftedTree[S], inlineCall: Inline def symbol: S = underlying.symbol def owner: Symbol = underlying.owner def tpe: TermType = inlineCall.substTypeParams(underlying.tpe) - override def inlinedFrom: List[InlineCall] = inlineCall :: underlying.inlinedFrom - override def inlinedArgs: Map[Symbol, Seq[TermTree]] = underlying.inlinedArgs + + override def scope(scoper: Scoper): Scope = + scoper.inlinedScope(underlying.scope(scoper), inlineCall) /** - * An inline call arg can capture a variable passed to another argument of the same call + * A lambda in an inline lambda can capture a val passed as argument to the inline call * Example: - * inline def withContext(ctx: Context)(f: Context ?=> T): T = f(ctx) + * inline def withContext(ctx: Context)(inline f: Context ?=> T): T = f(using ctx) * withContext(someCtx)(list.map()) * can capture someCtx + * + * @param params the params of the inline lambda + * @param inlineArgs the other args of the inline call */ -final case class InlinedFromArg[S](underlying: LiftedTree[S], params: Seq[TermSymbol], inlineArgs: Seq[TermTree]) +final case class InlinedFromArg[S](underlying: LiftedTree[S], lambdaParams: Seq[TermSymbol], inlineCall: InlineCall) extends LiftedTree[S]: def tree: Tree = underlying.tree def symbol: S = underlying.symbol def owner: Symbol = underlying.owner def tpe: TermType = underlying.tpe - override def inlinedFrom: List[InlineCall] = underlying.inlinedFrom - override def inlinedArgs: Map[Symbol, Seq[TermTree]] = underlying.inlinedArgs ++ params.map(_ -> inlineArgs) - -object LiftedTree: - // todo should also map the inlineArgs as a map Map[TermSymbol, TermTree] - private def collectPositions(liftedTree: LiftedTree[?])(using - Context, - ThrowOrWarn - ): Seq[SourcePosition] = - val positions = mutable.Set.empty[SourcePosition] - val alreadySeen = mutable.Set.empty[Symbol] - - def registerPosition(pos: SourcePosition): Unit = - if pos.isFullyDefined then positions += pos - - def loopCollect(symbol: Symbol)(collect: => Unit): Unit = - if !alreadySeen.contains(symbol) then - alreadySeen += symbol - collect - - class Traverser(inlinedFrom: List[InlineCall], inlinedArgs: Map[Symbol, Seq[TermTree]]) extends TreeTraverser: - private val inlineMapping: Map[Symbol, TermTree] = inlinedFrom.headOption.toSeq.flatMap(_.paramsMap).toMap - override def traverse(tree: Tree): Unit = - tree match - case _: TypeTree => () - case tree: TermReferenceTree => - for sym <- tree.safeSymbol do - for arg <- inlineMapping.get(sym) do - registerPosition(arg.pos) - loopCollect(sym)(Traverser(inlinedFrom.tail, inlinedArgs).traverse(arg)) - for args <- inlinedArgs.get(sym) do - val args = inlinedArgs(sym) - loopCollect(sym)(args.foreach(traverse)) - for tree <- sym.tree if sym.isInline do - registerPosition(tree.pos) - loopCollect(sym)(traverse(tree)) - case _ => () - super.traverse(tree) - - registerPosition(liftedTree.tree.pos) - Traverser(liftedTree.inlinedFrom, liftedTree.inlinedArgs).traverse(liftedTree.tree) - positions.toSeq - end collectPositions - - def collectCapture(liftedTree: LiftedTree[?])(using Context, ThrowOrWarn): Seq[String] = - val capture = mutable.Set.empty[String] - val alreadySeen = mutable.Set.empty[Symbol] - - def loopCollect(symbol: Symbol)(collect: => Unit): Unit = - if !alreadySeen.contains(symbol) then - alreadySeen += symbol - collect - - class Traverser(inlinedFrom: List[InlineCall], inlinedArgs: Map[Symbol, Seq[TermTree]])(using Context) - extends TreeTraverser: - private val inlineMapping: Map[Symbol, TermTree] = inlinedFrom.headOption.toSeq.flatMap(_.paramsMap).toMap - override def traverse(tree: Tree): Unit = - tree match - case tree: TermReferenceTree => - for symbol <- tree.safeSymbol do - for arg <- inlineMapping.get(symbol) do - loopCollect(symbol)(Traverser(inlinedFrom.tail, inlinedArgs).traverse(arg)) - for args <- inlinedArgs.get(symbol) do loopCollect(symbol)(args.foreach(traverse)) - case _ => () - - tree match - case _: TypeTree => () - case ident: Ident => - for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do - capture += sym.nameStr - if sym.isLocal then - if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse)) - else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse)) - case _ => super.traverse(tree) - end Traverser - - val traverser = Traverser(liftedTree.inlinedFrom, liftedTree.inlinedArgs) - def traverse(tree: LiftedTree[?]): Unit = - tree match - case term: LocalTermDef if term.symbol.isModuleVal => - loopCollect(term.symbol)(term.symbol.moduleClass.flatMap(_.tree).foreach(traverser.traverse)) - case term: LocalTermDef => - loopCollect(term.symbol)(traverser.traverse(term.tree)) - case lambda: LambdaTree => loopCollect(lambda.symbol(0))(lambda.tree) - case InlinedFromDef(underlying, inlineCall) => traverse(underlying) - case InlinedFromArg(underlying, params, inlineArgs) => traverse(underlying) - case tree => traverser.traverse(tree.tree) - traverse(liftedTree) - capture.toSeq - end collectCapture + + override def scope(scoper: Scoper): Scope = + scoper.inlinedFromLambdaArg(underlying.scope(scoper), lambdaParams, inlineCall) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTreeCollector.scala b/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTreeCollector.scala index dbe67b0..90b0b5c 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTreeCollector.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/LiftedTreeCollector.scala @@ -1,14 +1,13 @@ package ch.epfl.scala.decoder.internal -import tastyquery.Trees.* -import scala.collection.mutable +import ch.epfl.scala.decoder.ThrowOrWarn +import tastyquery.Contexts.* import tastyquery.Symbols.* import tastyquery.Traversers.* -import tastyquery.Contexts.* -import tastyquery.SourcePosition +import tastyquery.Trees.* import tastyquery.Types.* -import tastyquery.Traversers -import ch.epfl.scala.decoder.ThrowOrWarn + +import scala.collection.mutable /** * Collect all trees that could be lifted by the compiler: local defs, lambdas, try clauses, by-name applications @@ -16,11 +15,11 @@ import ch.epfl.scala.decoder.ThrowOrWarn * and compute the capture. */ object LiftedTreeCollector: - def collect(sym: Symbol)(using Context, Definitions, ThrowOrWarn): Seq[LiftedTree[?]] = + def collect(sym: Symbol)(using Context, ThrowOrWarn): Seq[LiftedTree[?]] = val collector = LiftedTreeCollector(sym) sym.tree.toSeq.flatMap(collector.collect) -class LiftedTreeCollector private (root: Symbol)(using Context, Definitions, ThrowOrWarn): +class LiftedTreeCollector private (root: Symbol)(using Context, ThrowOrWarn): private val inlinedTrees = mutable.Map.empty[TermSymbol, Seq[LiftedTree[?]]] private var owner = root @@ -62,10 +61,9 @@ class LiftedTreeCollector private (root: Symbol)(using Context, Definitions, Thr val liftedTrees = inlinedTrees.getOrElseUpdate(inlineCall.symbol, collectInlineDef(inlineCall.symbol)) buffer ++= liftedTrees.map(InlinedFromDef(_, inlineCall)) buffer ++= inlineCall.args.flatMap { arg => - extractLambda(arg) match + arg.asLambda match case Some(lambda) => - val params = lambda.meth.symbol.asTerm.paramSymbols - collect(arg).map(InlinedFromArg(_, params, inlineCall.args)) + collect(arg).map(InlinedFromArg(_, lambda.paramSymbols, inlineCall)) case None => collect(arg) } super.traverse(inlineCall.termRefTree) @@ -82,12 +80,6 @@ class LiftedTreeCollector private (root: Symbol)(using Context, Definitions, Thr inlinedTrees(symbol) = Seq.empty // break recursion symbol.tree.flatMap(extractRHS).toSeq.flatMap(collect) - private def extractLambda(tree: StatementTree): Option[Lambda] = - tree match - case lambda: Lambda => Some(lambda) - case block: Block => extractLambda(block.expr) - case _ => None - private def extractRHS(tree: DefTree): Option[TermTree] = tree match case tree: DefDef => tree.rhs diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/LocalVariable.scala b/src/main/scala/ch/epfl/scala/decoder/internal/LocalVariable.scala new file mode 100644 index 0000000..75075ab --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/internal/LocalVariable.scala @@ -0,0 +1,25 @@ +package ch.epfl.scala.decoder.internal + +import tastyquery.Types.Type +import tastyquery.SourcePosition +import tastyquery.Symbols.* +import tastyquery.Contexts.Context + +sealed trait LocalVariable: + def sym: Symbol + def scope: Scope + def tpe: Type + + def positions: Set[SourcePosition] = scope.allPositions + +object LocalVariable: + case class This(sym: ClassSymbol, scope: Scope) extends LocalVariable: + def tpe: Type = sym.thisType + + case class ValDef(sym: TermSymbol, scope: Scope) extends LocalVariable: + def tpe: Type = sym.declaredType.requireType + + case class InlinedFromDef(underlying: LocalVariable, inlineCall: InlineCall, scope: Scope)(using Context) + extends LocalVariable: + def sym: Symbol = underlying.sym + def tpe: Type = inlineCall.substTypeParams(underlying.tpe) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala b/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala index 3ee46bf..8cbac7d 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala @@ -5,7 +5,7 @@ import scala.util.matching.Regex object Patterns: object LocalClass: - def unapply(cls: binary.ClassType): Option[(String, String, Option[String])] = + def unapply(cls: binary.BinaryClass): Option[(String, String, Option[String])] = val decodedClassName = NameTransformer.decode(cls.name.split('.').last) unapply(decodedClassName) @@ -16,7 +16,7 @@ object Patterns: .map(xs => (xs(0), xs(1), Option(xs(2)).map(_.stripPrefix("$")).filter(_.nonEmpty))) object AnonClass: - def unapply(cls: binary.ClassType): Option[(String, Option[String])] = + def unapply(cls: binary.BinaryClass): Option[(String, Option[String])] = val decodedClassName = NameTransformer.decode(cls.name.split('.').last) unapply(decodedClassName) @@ -26,7 +26,7 @@ object Patterns: .map(xs => (xs(0), Option(xs(1)).map(_.stripPrefix("$")).filter(_.nonEmpty))) object InnerClass: - def unapply(cls: binary.ClassType): Option[String] = + def unapply(cls: binary.BinaryClass): Option[String] = val decodedClassName = NameTransformer.decode(cls.name.split('.').last) "(.+)\\$(.+)".r .unapplySeq(decodedClassName) @@ -50,6 +50,11 @@ object Patterns: (!method.isBridge && !method.isStatic) && "(.*)\\$\\$\\$outer".r.unapplySeq(NameTransformer.decode(method.name)).isDefined + def unapply(field: binary.Field): Boolean = field.name == "$outer" + + def unapply(variable: binary.Variable): Boolean = variable.name == "$outer" + end Outer + object AnonFun: def unapply(method: binary.Method): Boolean = !method.isBridge && "(.*)\\$anonfun\\$\\d+".r.unapplySeq(NameTransformer.decode(method.name)).isDefined @@ -73,7 +78,7 @@ object Patterns: object LocalLazyInit: def unapply(method: binary.Method): Option[Seq[String]] = - if method.isBridge || !method.allParameters.forall(_.isGenerated) then None + if method.isBridge || !method.parameters.forall(_.isGeneratedParam) then None else method.extractFromDecodedNames("""(.+)\$lzyINIT\d+\$(\d+)""".r)(_(0).stripSuffix("$")) object SuperArg: @@ -148,36 +153,47 @@ object Patterns: def unapply(field: binary.Field): Option[Int] = """OFFSET\$(?:_m_)?(\d+)""".r.unapplySeq(field.name).map(xs => xs(0).toInt) - object OuterField: - def unapply(field: binary.Field): Boolean = field.name == "$outer" - object SerialVersionUID: def unapply(field: binary.Field): Boolean = field.name == "serialVersionUID" object AnyValCapture: - def unapply(field: binary.Field): Boolean = - field.name.matches("\\$this\\$\\d+") + def unapply(field: binary.Field): Boolean = unapply(field.name) + def unapply(variable: binary.Variable): Boolean = unapply(variable.name) + private def unapply(name: String): Boolean = name.matches("\\$this\\$\\d+") object Capture: def unapply(field: binary.Field): Option[Seq[String]] = field.extractFromDecodedNames("(.+)\\$\\d+".r)(xs => xs(0)) - object LazyValBitmap: - def unapply(field: binary.Field): Option[String] = - "(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0)) + def unapply(variable: binary.Variable): Option[String] = + "(\\$\\d+)\\$\\$\\d+".r // anon variable + .unapplySeq(variable.name) + .orElse("(.+)\\$\\d+".r.unapplySeq(variable.name)) + .map(xs => xs(0)) + end Capture object CapturedLzyVariable: - def unapply(variable: binary.Variable): Option[String] = - "(.+)\\$lzy1\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0)) + def unapply(field: binary.Field): Option[Seq[String]] = + field.extractFromDecodedNames("(.+)\\$lzy\\d+\\$\\d+".r)(xs => xs(0)) - object CapturedVariable: def unapply(variable: binary.Variable): Option[String] = - "(.+)\\$\\d+".r.unapplySeq(variable.name).map(xs => xs(0)) + "(\\$\\d+)\\$\\$lzy\\d+\\$\\d+".r // anon variable + .unapplySeq(variable.name) + .orElse("(.+)\\$lzy\\d+\\$\\d+".r.unapplySeq(variable.name)) + .map(xs => xs(0)) + + object LazyValBitmap: + def unapply(field: binary.Field): Option[String] = + "(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0)) object CapturedTailLocalVariable: def unapply(variable: binary.Variable): Option[String] = "(.+)\\$tailLocal\\d+(\\$\\d+)?".r.unapplySeq(variable.name).map(xs => xs(0)) + object CapturedProxy: + def unapply(variable: binary.Variable): Option[(String, String)] = + "((.+)\\$proxy\\d+)\\$\\d+".r.unapplySeq(variable.name).map(xs => (xs(0), xs(1))) + object This: def unapply(variable: binary.Variable): Boolean = variable.name == "this" @@ -195,6 +211,14 @@ object Patterns: object InlinedThis: def unapply(variable: binary.Variable): Boolean = variable.name.endsWith("_this") + object XDollar: + def unapply(variable: binary.Variable): Option[Int] = + "x\\$(\\d+)".r.unapplySeq(variable.name).map(xs => xs(0).toInt) + + object V: + def unapply(variable: binary.Variable): Option[Int] = + "v(\\d+)".r.unapplySeq(variable.name).map(xs => xs(0).toInt) + extension (field: binary.Field) private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] = val extracted = field.unexpandedDecodedNames diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/Scope.scala b/src/main/scala/ch/epfl/scala/decoder/internal/Scope.scala new file mode 100644 index 0000000..19d6d84 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/internal/Scope.scala @@ -0,0 +1,23 @@ +package ch.epfl.scala.decoder.internal + +import tastyquery.SourcePosition +import tastyquery.Symbols.TermSymbol +import tastyquery.SourceFile +import tastyquery.Contexts.Context +import ch.epfl.scala.decoder.ThrowOrWarn + +/** + * The description of a scope in the code. + * + * @param position the main position of the scope + * @param inlinedPositions the other inlined positions in the scope + * @param inlinedParams the inlined params, that should later be replaced + */ +case class Scope( + position: SourcePosition, + inlinedPositions: Set[SourcePosition], + capturedVariables: Set[TermSymbol] +): + def allPositions: Set[SourcePosition] = Set(position) ++ inlinedPositions + def sourceFile: SourceFile = position.sourceFile + def inlineParams: Set[TermSymbol] = capturedVariables.filter(_.isParamInInlineMethod) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/Scoper.scala b/src/main/scala/ch/epfl/scala/decoder/internal/Scoper.scala new file mode 100644 index 0000000..88dfd37 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/internal/Scoper.scala @@ -0,0 +1,112 @@ +package ch.epfl.scala.decoder.internal + +import tastyquery.Contexts.Context +import ch.epfl.scala.decoder.ThrowOrWarn +import tastyquery.SourcePosition +import scala.collection.mutable +import tastyquery.Trees.* +import tastyquery.Symbols.* +import tastyquery.Traversers.TreeTraverser +import tastyquery.Modifiers.TermSymbolKind + +// computes and caches the scopes of symbols or trees +class Scoper(using Context, ThrowOrWarn): + private val cache = mutable.Map.empty[Symbol, LocalScope] + + private case class LocalScope( + position: SourcePosition, + capturedSyms: Set[TermSymbol], + inlineSyms: Set[TermSymbol] + ): + def capturedInlineParams: Set[TermSymbol] = capturedSyms.filter(_.isParamInInlineMethod) + def capturedMethods: Set[TermSymbol] = + capturedSyms.filter(s => s.isLocal && (s.isMethod || s.isModuleVal || s.isLazyVal)) + def capturedVariables: Set[TermSymbol] = capturedSyms.filter(s => !s.isMethod) + + /** + * Compute the scope inlined from an inline call: + * - compute the scope of the inlined arguments + * - use the pos of the inlineCall as main position + */ + def inlinedScope(scope: Scope, inlineCall: InlineCall): Scope = + val argScopes = + for + param <- scope.inlineParams + arg <- inlineCall.paramsMap.get(param).toSeq + yield getScope(arg) + + val (position, inlinedPositions) = + if scope.sourceFile == inlineCall.pos.sourceFile then + (scope.position, scope.inlinedPositions ++ argScopes.flatMap(_.allPositions)) + else (inlineCall.pos, scope.allPositions ++ argScopes.flatMap(_.inlinedPositions)) + val capturedVariables = scope.capturedVariables ++ argScopes.flatMap(_.capturedVariables) + Scope(position, inlinedPositions, capturedVariables) + + def inlinedFromLambdaArg(scope: Scope, lambdaParams: Seq[TermSymbol], inlineCall: InlineCall): Scope = + if lambdaParams.toSet.intersect(scope.capturedVariables).nonEmpty then + val argScopes = inlineCall.args.map(getScope) + val capturedVariables = scope.capturedVariables ++ argScopes.flatMap(_.capturedVariables) + scope.copy(capturedVariables = capturedVariables) + else scope + + def getScope(tree: Tree): Scope = buildScope(getLocalScope(tree)) + def getScope(sym: Symbol): Scope = buildScope(getLocalScope(sym)) + + private def buildScope(localScope: LocalScope): Scope = + def loopInline(acc: Map[TermSymbol, LocalScope]): Iterable[LocalScope] = + val remaining = acc.values.flatMap(_.inlineSyms).toSet.filter(!acc.contains(_)) + if remaining.isEmpty then acc.values + else loopInline(acc ++ getLocalScopes(remaining)) + def loopCapture(acc: Map[TermSymbol, LocalScope]): Iterable[LocalScope] = + val remaining = acc.values.flatMap(_.capturedMethods).toSet.filter(!acc.contains(_)) + if remaining.isEmpty then acc.values + else loopCapture(acc ++ getLocalScopes(remaining)) + val allInlined = loopInline(getLocalScopes(localScope.inlineSyms)) + val allCaptured = loopCapture(getLocalScopes(localScope.capturedMethods)) + Scope( + localScope.position, + allInlined.map(_.position).toSet, + localScope.capturedVariables ++ allCaptured.flatMap(_.capturedVariables) + ) + + private def getLocalScopes(syms: Set[TermSymbol]): Map[TermSymbol, LocalScope] = + syms.map(s => s -> getLocalScope(s)).toMap + + private def getLocalScope(sym: Symbol): LocalScope = + sym match + case sym: TermSymbol if sym.isModuleVal => getLocalScope(sym.moduleClass.get) + case _ => + sym.tree match + case None => LocalScope(SourcePosition.NoPosition, Set.empty, Set.empty) + case Some(tree) => cache.getOrElseUpdate(sym, getLocalScope(tree)) + + private def getLocalScope(tree: Tree): LocalScope = + val inlineSyms = mutable.Set.empty[TermSymbol] + val capturedSyms = mutable.Set.empty[TermSymbol] + val localSyms = mutable.Set.empty[TermSymbol] + object Traverser extends TreeTraverser: + override def traverse(tree: Tree): Unit = + tree match + case tree: ValOrDefDef => + localSyms += tree.symbol + case bind: Bind => + localSyms += bind.symbol + case ident: Ident => + for + sym <- ident.safeTermSymbol + // sym.isLocal is not enough because of primary ctor params + if !localSyms.contains(sym) && (sym.isLocal || sym.isVal) + do capturedSyms += sym + case _ => () + + // inline call + tree match + case tree: TermReferenceTree => + for sym <- tree.safeTermSymbol if sym.isInline do inlineSyms += sym + case _ => () + + tree match + case _: TypeTree => () + case _ => super.traverse(tree) + Traverser.traverse(tree) + LocalScope(tree.pos, capturedSyms.toSet, inlineSyms.toSet) diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala b/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala index 859c439..4eecde6 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/VariableCollector.scala @@ -1,100 +1,106 @@ package ch.epfl.scala.decoder.internal -import tastyquery.Trees.* -import scala.collection.mutable +import ch.epfl.scala.decoder.ThrowOrWarn +import ch.epfl.scala.decoder.binary +import tastyquery.Contexts.* +import tastyquery.SourceFile import tastyquery.Symbols.* import tastyquery.Traversers.* -import tastyquery.Contexts.* -import tastyquery.SourcePosition +import tastyquery.Trees.* import tastyquery.Types.* -import tastyquery.Traversers -import ch.epfl.scala.decoder.ThrowOrWarn +import tastyquery.SourcePosition + +import scala.collection.mutable import scala.languageFeature.postfixOps -import tastyquery.SourceFile object VariableCollector: - def collectVariables(tree: Tree, sym: Option[TermSymbol] = None)(using Context, ThrowOrWarn): Set[LocalVariable] = - val collector = VariableCollector() + def collectVariables(scoper: Scoper, tree: Tree, sym: Option[TermSymbol] = None)(using + Context, + ThrowOrWarn + ): Set[LocalVariable] = + val collector = VariableCollector(scoper) collector.collect(tree, sym) -trait LocalVariable: - def sym: Symbol - def sourceFile: SourceFile - def startLine: Int - def endLine: Int - -object LocalVariable: - case class This(sym: ClassSymbol, sourceFile: SourceFile, startLine: Int, endLine: Int) extends LocalVariable - case class ValDef(sym: TermSymbol, sourceFile: SourceFile, startLine: Int, endLine: Int) extends LocalVariable - case class InlinedFromDef(underlying: LocalVariable, inlineCall: InlineCall) extends LocalVariable: - def sym: Symbol = underlying.sym - def startLine: Int = underlying.startLine - def endLine: Int = underlying.endLine - def sourceFile: SourceFile = underlying.sourceFile - -end LocalVariable - -class VariableCollector()(using Context, ThrowOrWarn) extends TreeTraverser: +class VariableCollector(scoper: Scoper)(using Context, ThrowOrWarn) extends TreeTraverser: private val inlinedVariables = mutable.Map.empty[TermSymbol, Set[LocalVariable]] def collect(tree: Tree, sym: Option[TermSymbol] = None): Set[LocalVariable] = val variables: mutable.Set[LocalVariable] = mutable.Set.empty - var previousTree: mutable.Stack[Tree] = mutable.Stack.empty + type ScopeTree = ClassDef | DefDef | Block | CaseDef | Inlined + var scopes: mutable.Stack[Scope] = mutable.Stack(scoper.getScope(tree)) object Traverser extends TreeTraverser: + def traverseDef(tree: Tree): Unit = + // traverse even if it's a DefDef or ClassDef + tree match + case tree: DefDef => + scoped(tree)(tree.paramLists.foreach(_.left.foreach(traverse))) + val isContextFun = tree.symbol.declaredType.returnType.safeDealias.exists(_.isContextFunction) + tree.rhs match + case Some(body @ Block(List(lambda: DefDef), expr)) if isContextFun => + // if the method returns a context function, we traverse the internal anonfun + traverseDef(lambda) + Traverser.traverse(expr) + case Some(tree) => Traverser.traverse(tree) + case None => () + case tree: ClassDef => tree.rhs.body.foreach(Traverser.traverse) + case tree: ValDef => tree.rhs.foreach(Traverser.traverse) + case _ => Traverser.traverse(tree) + override def traverse(tree: Tree): Unit = tree match - case valDefOrBind: (ValDef | Bind) => addValDefOrBind(valDefOrBind) + case tree: (ValDef | Bind) => addValDefOrBind(tree) case _ => () tree match case _: TypeTree => () - // case _: DefDef => () - case valDef: ValDef => traverse(valDef.rhs) - case bind: Bind => traverse(bind.body) + case _: DefDef | ClassDef => () case InlineCall(inlineCall) => val localVariables = - inlinedVariables.getOrElseUpdate(inlineCall.symbol, collectInlineDef(inlineCall.symbol)) - variables ++= localVariables.map(LocalVariable.InlinedFromDef(_, inlineCall)) - previousTree.push(inlineCall.termRefTree) - super.traverse(inlineCall.termRefTree) - previousTree.pop() - case _ => - previousTree.push(tree) - super.traverse(tree) - previousTree.pop() + inlinedVariables.getOrElseUpdate(inlineCall.symbol, collectInlineDef(inlineCall)) + variables ++= localVariables.map { v => + val scope = scoper.inlinedScope(v.scope, inlineCall) + LocalVariable.InlinedFromDef(v, inlineCall, scope) + } + inlineCall.args.foreach(traverse) + // inlined lambdas + for + inlineParam <- inlineCall.paramsMap.collect { case (sym, tree) if sym.isInline => tree } + inlinedLambda <- inlineParam.asLambda + lambdaTree <- inlinedLambda.tree + do traverseDef(lambdaTree) + case tree: (Block | CaseDef | Inlined) => scoped(tree)(super.traverse(tree)) + case _ => super.traverse(tree) + + private def scoped(tree: ScopeTree)(f: => Unit): Unit = + val scope = scoper.getScope(tree) + if scope.position.isFullyDefined then + scopes.push(scope) + f + scopes.pop() + else f private def addValDefOrBind(valDef: ValDef | Bind): Unit = val sym = valDef.symbol.asInstanceOf[TermSymbol] - previousTree - .collectFirst { case tree: (Block | CaseDef | DefDef) => tree } - .foreach { parentTree => - variables += - LocalVariable.ValDef( - sym, - parentTree.pos.sourceFile, - valDef.pos.startLine + 1, - parentTree.pos.endLine + 1 - ) - } + variables += LocalVariable.ValDef(sym, scopes.head) + end Traverser sym .flatMap(allOuterClasses) - .foreach(cls => - variables += LocalVariable.This( - cls, - cls.pos.sourceFile, - if cls.pos.isUnknown then -1 else cls.pos.startLine + 1, - if cls.pos.isUnknown then -1 else cls.pos.endLine + 1 - ) - ) - Traverser.traverse(tree) + .foreach: cls => + val scope = scoper.getScope(cls) + variables += LocalVariable.This(cls, scope) + Traverser.traverseDef(tree) variables.toSet end collect - private def collectInlineDef(symbol: TermSymbol): Set[LocalVariable] = - inlinedVariables(symbol) = Set.empty // break recursion - symbol.tree.toSet.flatMap(tree => collect(tree, Some(symbol))) + private def collectInlineDef(inlineCall: InlineCall): Set[LocalVariable] = + inlinedVariables(inlineCall.symbol) = Set.empty // break recursion + inlineCall.symbol.tree match + case Some(tree) => collect(tree, Some(inlineCall.symbol)) + case None => // inline def in stdLibPatches don't have trees + val scope = Scope(inlineCall.callTree.pos, Set.empty, Set.empty) + inlineCall.symbol.paramSymbols.map(LocalVariable.ValDef(_, scope)).toSet private def allOuterClasses(sym: Symbol): List[ClassSymbol] = def loop(sym: Symbol, acc: List[ClassSymbol]): List[ClassSymbol] = diff --git a/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala b/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala index d638d95..36268be 100644 --- a/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala +++ b/src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala @@ -1,18 +1,20 @@ package ch.epfl.scala.decoder.internal -import tastyquery.Symbols.* -import tastyquery.Trees.* -import tastyquery.Names.* -import tastyquery.Types.* -import tastyquery.Modifiers.* import ch.epfl.scala.decoder.* import ch.epfl.scala.decoder.binary -import tastyquery.SourcePosition +import ch.epfl.scala.decoder.binary.SourceLines import tastyquery.Contexts.* +import tastyquery.Modifiers.* +import tastyquery.Names.* import tastyquery.Signatures.* -import scala.util.control.NonFatal import tastyquery.SourceLanguage -import ch.epfl.scala.decoder.binary.SourceLines +import tastyquery.SourcePosition +import tastyquery.Symbols.* +import tastyquery.Trees.* +import tastyquery.Types.* + +import scala.annotation.tailrec +import scala.util.control.NonFatal extension (symbol: Symbol) def isTrait = symbol.isClass && symbol.asClass.isTrait @@ -60,23 +62,16 @@ extension (symbol: TermSymbol) overridingSymbolInLinearization(siteClass) == symbol def isConstructor = symbol.owner.isClass && symbol.isMethod && symbol.name == nme.Constructor + def isParam = + symbol.owner.isTerm && symbol.owner.asTerm.paramSymbols.contains(symbol) + def isParamInInlineMethod = + symbol.owner.isInline && symbol.isParam def paramSymbols: List[TermSymbol] = - symbol.tree.toList - .collect { case tree: DefDef => tree.paramLists } - .flatten - .collect { case Left(params) => params } - .flatten - .map(_.symbol) + symbol.paramSymss.collect { case Left(termSyms) => termSyms }.flatten def typeParamSymbols: List[LocalTypeParamSymbol] = - symbol.tree.toList - .collect { case tree: DefDef => tree.paramLists } - .flatten - .collect { case Right(typeParams) => typeParams } - .flatten - .map(_.symbol) - .collect { case sym: LocalTypeParamSymbol => sym } + symbol.paramSymss.collect { case Right(typeParamSyms) => typeParamSyms }.flatten extension [A, S[+X] <: IterableOnce[X]](xs: S[A]) def singleOpt: Option[A] = @@ -89,15 +84,14 @@ extension [A, S[+X] <: IterableOnce[X]](xs: S[A]) if xs.nonEmpty then xs else ys extension [T <: DecodedSymbol](xs: Seq[T]) - def singleOrThrow(symbol: binary.Symbol): T = + inline def singleOrThrow(symbol: binary.Symbol): T = singleOptOrThrow(symbol).getOrElse(notFound(symbol)) - def singleOrThrow(symbol: binary.Symbol, decodedOwner: DecodedSymbol): T = + inline def singleOrThrow(symbol: binary.Symbol, decodedOwner: DecodedSymbol): T = singleOptOrThrow(symbol).getOrElse(notFound(symbol, Some(decodedOwner))) - def singleOptOrThrow(symbol: binary.Symbol): Option[T] = - if xs.size > 1 then ambiguous(symbol, xs) - else xs.headOption + inline def singleOptOrThrow(symbol: binary.Symbol): Option[T] = + if xs.size > 1 then ambiguous(symbol, xs) else xs.headOption extension (name: TermName) def isPackageObject: Boolean = @@ -144,6 +138,15 @@ extension (tpe: Type) ctx.defn.ArrayTypeOf(tpe.elemType).erased(isReturnType = false) case _ => tpe.erased(isReturnType = false) + def expandContextFunctions(using Context, ThrowOrWarn): (List[Type], Type) = + @tailrec def rec(tpe: Type, acc: List[Type]): (List[Type], Type) = + tpe.safeDealias match + case Some(tpe: AppliedType) if tpe.tycon.isContextFunction => + val argsAsTypes = tpe.args.map(_.highIfWildcard) + rec(argsAsTypes.last, acc ::: argsAsTypes.init) + case _ => (acc, tpe) + rec(tpe, List.empty) + private def erased(isReturnType: Boolean)(using Context, ThrowOrWarn): Option[ErasedTypeRef] = tryOrNone(ErasedTypeRef.erase(tpe, SourceLanguage.Scala3, keepUnit = isReturnType)) @@ -195,18 +198,24 @@ extension (tree: Apply) tree match case tree: Apply => rec(tree.fun) case tree: TypeApply => rec(tree.fun) - case tree: TermReferenceTree => tree.safeSymbol.collect { case sym: TermSymbol => sym } + case tree: TermReferenceTree => tree.safeTermSymbol case _ => None rec(tree) extension (tree: TermReferenceTree) - def safeSymbol(using Context, ThrowOrWarn): Option[PackageSymbol | TermSymbol] = - tryOrNone(tree.symbol) + def safeTermSymbol(using Context, ThrowOrWarn): Option[TermSymbol] = + tryOrNone(tree.symbol).collect { case sym: TermSymbol => sym } extension (tree: TermTree) def safeTpe(using Context, ThrowOrWarn): Option[TermType] = tryOrNone(tree.tpe) + def asLambda(using Context): Option[TermSymbol] = + tree match + case Block(_, expr) => expr.asLambda + case Lambda(meth, _) => Some(meth.symbol.asTerm) + case _ => None + extension (pos: SourcePosition) def isFullyDefined: Boolean = !pos.isUnknown && pos.hasLineColumnInformation @@ -293,11 +302,12 @@ extension (field: binary.Instruction.Field) field.opcode == 0xb5 || field.opcode == 0xb3 extension (method: DecodedMethod) - def isGenerated: Boolean = + def isGenerated(using ctx: Context): Boolean = method match case method: DecodedMethod.ValOrDefDef => val sym = method.symbol - (sym.isGetter && (!sym.owner.isTrait || !sym.isModuleOrLazyVal)) || // getter + def isThreadUnsafe = sym.hasAnnotation(Definitions.threadUnsafeClass) + (sym.isGetter && (!sym.isModuleOrLazyVal || !sym.owner.isTrait) && !isThreadUnsafe) || // getter (sym.isLocal && sym.isModuleOrLazyVal) || // local def sym.isSetter || (sym.isSynthetic && !sym.isLocal) || @@ -318,22 +328,3 @@ extension (method: DecodedMethod) case _: DecodedMethod.SAMOrPartialFunctionConstructor => true case method: DecodedMethod.InlinedMethod => method.underlying.isGenerated case _ => false - -extension (field: DecodedField) - def isGenerated: Boolean = - field match - case field: DecodedField.ValDef => false - case field: DecodedField.ModuleVal => true - case field: DecodedField.LazyValOffset => true - case field: DecodedField.Outer => true - case field: DecodedField.SerialVersionUID => true - case field: DecodedField.Capture => true - case field: DecodedField.LazyValBitmap => true - -extension (variable: DecodedVariable) - def isGenerated: Boolean = - variable match - case variable: DecodedVariable.ValDef => false - case variable: DecodedVariable.CapturedVariable => true - case variable: DecodedVariable.This => true - case variable: DecodedVariable.AnyValThis => true diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala index ffb4fb4..f6cb1ec 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/AsmVariable.scala @@ -6,7 +6,12 @@ import ch.epfl.scala.decoder.binary.SourceLines import ch.epfl.scala.decoder.binary.Method import ch.epfl.scala.decoder.binary.Variable -class AsmVariable(val name: String, val `type`: Type, val declaringMethod: Method, val sourceLines: Option[SourceLines]) - extends Variable: +class AsmVariable( + val name: String, + val `type`: Type, + val declaringMethod: Method, + val sourceLines: Option[SourceLines], + override val isParameter: Boolean +) extends Variable: - override def toString: String = s"$name: ${`type`.name}" + override def toString: String = s"${`type`.name} $name" diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala index 3b80b8f..886ac29 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/ExtraMethodInfo.scala @@ -2,15 +2,26 @@ package ch.epfl.scala.decoder.javareflect import ch.epfl.scala.decoder.binary import org.objectweb.asm +import scala.collection.SeqMap +import ch.epfl.scala.decoder.binary.SourceLines private case class ExtraMethodInfo( sourceLines: Option[binary.SourceLines], instructions: Seq[binary.Instruction], variables: Seq[ExtraMethodInfo.Variable], - labels: Map[asm.Label, Int] -) + labelLines: SeqMap[asm.Label, Int] +): + val labels = labelLines.keys.toSeq + + // not parameters + def localVariables: Seq[ExtraMethodInfo.Variable] = + variables.filter(v => v.start != labelLines.head(0) || v.name == "this") + + def debugLines(variable: ExtraMethodInfo.Variable): Seq[Int] = + val startIdx = labels.indexOf(variable.start) + val endIdx = labels.indexOf(variable.end) + labelLines.values.slice(startIdx, endIdx).toSeq.distinct private object ExtraMethodInfo: - def empty: ExtraMethodInfo = ExtraMethodInfo(None, Seq.empty, Seq.empty, Map.empty) + def empty: ExtraMethodInfo = ExtraMethodInfo(None, Seq.empty, Seq.empty, SeqMap.empty) case class Variable(name: String, descriptor: String, signature: String, start: asm.Label, end: asm.Label) - case class LineNumberNode(line: Int, start: asm.Label, sourceName: String) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala index fa569b8..cf68e91 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectClass.scala @@ -3,9 +3,11 @@ package ch.epfl.scala.decoder.javareflect import ch.epfl.scala.decoder.binary import scala.util.matching.Regex import scala.jdk.CollectionConverters.* +import java.lang.reflect.Method +import java.lang.reflect.Constructor class JavaReflectClass(val cls: Class[?], extraInfo: ExtraClassInfo, override val classLoader: JavaReflectLoader) - extends binary.ClassType: + extends binary.BinaryClass: override def name: String = cls.getTypeName override def superclass = Option(cls.getSuperclass).map(classLoader.loadClass) override def interfaces = cls.getInterfaces.toList.map(classLoader.loadClass) @@ -32,15 +34,13 @@ class JavaReflectClass(val cls: Class[?], extraInfo: ExtraClassInfo, override va if showSpan.isEmpty then cls.toString else s"$cls $showSpan" override def declaredMethods: Seq[binary.Method] = - cls.getDeclaredMethods.map { m => - val sig = JavaReflectUtils.signature(m) - val methodInfo = extraInfo.getMethodInfo(sig) - JavaReflectMethod(m, sig, methodInfo, classLoader) - } ++ cls.getDeclaredConstructors.map { c => - val sig = JavaReflectUtils.signature(c) - val methodInfo = extraInfo.getMethodInfo(sig) - JavaReflectConstructor(c, sig, methodInfo, classLoader) - } + cls.getDeclaredMethods + .++[Method | Constructor[?]](cls.getDeclaredConstructors) + .map { m => + val sig = JavaReflectUtils.signature(m) + val methodInfo = extraInfo.getMethodInfo(sig) + JavaReflectMethod(m, sig, methodInfo, classLoader) + } override def declaredFields: Seq[binary.Field] = cls.getDeclaredFields().map(f => JavaReflectField(f, classLoader)) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala deleted file mode 100644 index e31c534..0000000 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectConstructor.scala +++ /dev/null @@ -1,42 +0,0 @@ -package ch.epfl.scala.decoder.javareflect - -import ch.epfl.scala.decoder.binary -import ch.epfl.scala.decoder.binary.Variable -import java.lang.reflect.Constructor -import java.lang.reflect.Method -import ch.epfl.scala.decoder.binary.SignedName - -class JavaReflectConstructor( - constructor: Constructor[?], - val signedName: SignedName, - extraInfos: ExtraMethodInfo, - loader: JavaReflectLoader -) extends binary.Method: - - override def variables: Seq[Variable] = Seq.empty - - override def returnType: Option[binary.Type] = Some(loader.loadClass(classOf[Unit])) - - override def returnTypeName: String = "void" - - override def declaringClass: binary.ClassType = - loader.loadClass(constructor.getDeclaringClass) - - override def allParameters: Seq[binary.Parameter] = - constructor.getParameters.map(JavaReflectParameter.apply(_, loader)) - - override def name: String = "" - - override def isBridge: Boolean = false - - override def isStatic: Boolean = false - - override def isFinal: Boolean = true - - override def isConstructor: Boolean = true - - override def toString: String = constructor.toString - - override def sourceLines: Option[binary.SourceLines] = extraInfos.sourceLines - - override def instructions: Seq[binary.Instruction] = extraInfos.instructions diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectField.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectField.scala index 4ddcc3a..e5dc83c 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectField.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectField.scala @@ -10,7 +10,7 @@ class JavaReflectField(field: Field, loader: JavaReflectLoader) extends binary.F override def sourceLines: Option[binary.SourceLines] = None - override def declaringClass: binary.ClassType = loader.loadClass(field.getDeclaringClass) + override def declaringClass: binary.BinaryClass = loader.loadClass(field.getDeclaringClass) override def isStatic: Boolean = Modifier.isStatic(field.getModifiers) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala index c9d1e74..b4cc707 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectLoader.scala @@ -7,6 +7,7 @@ import org.objectweb.asm import java.io.IOException import java.net.URLClassLoader import java.nio.file.Path +import scala.collection.SeqMap class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extends BinaryClassLoader: private val loadedClasses: mutable.Map[Class[?], JavaReflectClass] = mutable.Map.empty @@ -57,7 +58,6 @@ class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extend override def visitLabel(label: asm.Label): Unit = labels += label override def visitLineNumber(line: Int, start: asm.Label): Unit = - // println("line: " + (line) + " start: " + start + " sourceName: " + sourceName) labelLines += start -> line override def visitFieldInsn(opcode: Int, owner: String, name: String, descriptor: String): Unit = instructions += Instruction.Field(opcode, owner.replace('/', '.'), name, descriptor) @@ -85,16 +85,15 @@ class JavaReflectLoader(classLoader: ClassLoader, loadExtraInfo: Boolean) extend allLines ++= labelLines.values val sourceLines = Option.when(sourceName.nonEmpty)(SourceLines(sourceName, labelLines.values.toSeq)) var latestLine: Option[Int] = None - val labelsWithLines: mutable.Map[asm.Label, Int] = mutable.Map.empty - for label <- labels - do + val labelsAndLines = labels.toSeq.flatMap { label => latestLine = labelLines.get(label).orElse(latestLine) - latestLine.foreach(line => labelsWithLines += label -> line) + latestLine.map(label -> _) + } extraInfos += SignedName(name, descriptor) -> ExtraMethodInfo( sourceLines, instructions.toSeq, variables.toSeq, - labelsWithLines.toMap + SeqMap(labelsAndLines*) ) reader.accept(visitor, asm.Opcodes.ASM9) val sourceLines = Option.when(sourceName.nonEmpty)(SourceLines(sourceName, allLines.toSeq)) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala index aaf86c3..fd1e51e 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectMethod.scala @@ -3,30 +3,44 @@ package ch.epfl.scala.decoder.javareflect import ch.epfl.scala.decoder.binary import java.lang.reflect.Method +import java.lang.reflect.Constructor import java.lang.reflect.Modifier -import ch.epfl.scala.decoder.binary.SignedName -import ch.epfl.scala.decoder.binary.Instruction import org.objectweb.asm -import ch.epfl.scala.decoder.binary.SourceLines class JavaReflectMethod( - method: Method, - val signedName: SignedName, + method: Method | Constructor[?], + val signedName: binary.SignedName, extraInfos: ExtraMethodInfo, loader: JavaReflectLoader ) extends binary.Method: override def returnType: Option[binary.Type] = - Option(method.getReturnType).map(loader.loadClass) + method match + case _: Constructor[?] => Some(loader.loadClass(classOf[Unit])) + case m: Method => Option(m.getReturnType).map(loader.loadClass) - override def returnTypeName: String = method.getReturnType.getName - - override def declaringClass: binary.ClassType = + override def declaringClass: binary.BinaryClass = loader.loadClass(method.getDeclaringClass) - override def allParameters: Seq[binary.Parameter] = - method.getParameters.map(JavaReflectParameter.apply(_, loader)) + override def parameters: Seq[binary.Parameter] = + method.getParameters.map(JavaReflectParameter.apply(_, this, loader)) - override def name: String = method.getName + override def variables: Seq[binary.Variable] = + val localVariables = + for variable <- extraInfos.localVariables yield + val typeName = asm.Type.getType(variable.descriptor).getClassName + val isParameter = extraInfos.labels.headOption.contains(variable.start) + val sourceLines = + for + sourceName <- sourceName + lines = extraInfos.debugLines(variable) + if lines.nonEmpty + yield binary.SourceLines(sourceName, lines) + AsmVariable(variable.name, loader.loadClass(typeName), this, sourceLines, isParameter) + parameters ++ localVariables + + override def name: String = method match + case _: Constructor[?] => "" + case m: Method => m.getName override def isStatic: Boolean = Modifier.isStatic(method.getModifiers) @@ -35,26 +49,12 @@ class JavaReflectMethod( override def toString: String = if showSpan.isEmpty then method.toString else s"$method $showSpan" - override def isBridge: Boolean = method.isBridge + override def isBridge: Boolean = method match + case _: Constructor[?] => false + case m: Method => m.isBridge - override def isConstructor: Boolean = false + override def isConstructor: Boolean = method.isInstanceOf[Constructor[?]] override def sourceLines: Option[binary.SourceLines] = extraInfos.sourceLines override def instructions: Seq[binary.Instruction] = extraInfos.instructions - - override def variables: Seq[binary.Variable] = - for variable <- extraInfos.variables - yield - val typeName = asm.Type.getType(variable.descriptor).getClassName - val sourceLines = - for - sourceName <- sourceName - line <- extraInfos.labels.get(variable.start) - yield SourceLines(sourceName, Seq(line)) - AsmVariable( - variable.name, - loader.loadClass(typeName), - this, - sourceLines - ) diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectParameter.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectParameter.scala index 39fedf2..03e3784 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectParameter.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectParameter.scala @@ -4,11 +4,15 @@ import ch.epfl.scala.decoder.binary import java.lang.reflect.Parameter -class JavaReflectParameter(parameter: Parameter, loader: JavaReflectLoader) extends binary.Parameter: +class JavaReflectParameter(parameter: Parameter, val declaringMethod: binary.Method, loader: JavaReflectLoader) + extends binary.Parameter: override def name: String = parameter.getName - override def sourceLines: Option[binary.SourceLines] = None + + override def sourceLines = declaringMethod.sourceLines override def `type`: binary.Type = loader.loadClass(parameter.getType) - override def toString: String = parameter.toString + override def toString: String = + try parameter.toString + catch case _: java.lang.ArrayIndexOutOfBoundsException => parameter.getName diff --git a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectUtils.scala b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectUtils.scala index bc6d97e..6e2b56e 100644 --- a/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectUtils.scala +++ b/src/main/scala/ch/epfl/scala/decoder/javareflect/JavaReflectUtils.scala @@ -1,9 +1,10 @@ package ch.epfl.scala.decoder.javareflect -import java.lang.reflect.Method -import java.lang.reflect.Constructor import ch.epfl.scala.decoder.binary.SignedName +import java.lang.reflect.Constructor +import java.lang.reflect.Method + object JavaReflectUtils: val primitiveSigs = Map[Class[?], String]( classOf[Byte] -> "B", @@ -17,14 +18,15 @@ object JavaReflectUtils: classOf[Unit] -> "V" ) - def signature(method: Method): SignedName = + def signature(method: Method | Constructor[?]): SignedName = + val name = method match + case _: Constructor[?] => "" + case m: Method => m.getName val params = method.getParameterTypes.map(signature) - val returnType = signature(method.getReturnType) - SignedName(method.getName, s"(${params.mkString})$returnType") - - def signature(ctr: Constructor[?]): SignedName = - val params = ctr.getParameterTypes.map(signature) - SignedName("", s"(${params.mkString})V") + val returnType = method match + case _: Constructor[?] => "V" + case m: Method => signature(m.getReturnType) + SignedName(name, s"(${params.mkString})$returnType") def signature(cls: Class[?]): String = if cls.isPrimitive then primitiveSigs(cls) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala index 22cfe9e..b3ccb05 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClass.scala @@ -7,14 +7,14 @@ import scala.jdk.CollectionConverters.* import scala.util.control.NonFatal /* A class or interface */ -class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with ClassType: +class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with BinaryClass: override def classLoader: BinaryClassLoader = JdiClassLoader(ref.classLoader) - override def superclass: Option[ClassType] = ref match + override def superclass: Option[BinaryClass] = ref match case cls: com.sun.jdi.ClassType => Some(JdiClass(cls.superclass)) case _ => None - override def interfaces: Seq[ClassType] = ref match + override def interfaces: Seq[BinaryClass] = ref match case cls: com.sun.jdi.ClassType => cls.interfaces.asScala.toSeq.map(JdiClass.apply) case interface: com.sun.jdi.InterfaceType => interface.superinterfaces.asScala.toSeq.map(JdiClass.apply) @@ -30,9 +30,9 @@ class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with ClassTy override def declaredMethods: Seq[Method] = ref.methods.asScala.map(JdiMethod(_)).toSeq - override def declaredField(name: String): Option[Field] = None + override def declaredField(name: String): Option[Field] = declaredFields.find(_.name == name) - override def declaredFields: Seq[Field] = Seq.empty + override def declaredFields: Seq[Field] = ref.fields.asScala.map(JdiField(_)).toSeq private[jdi] def constantPool: ConstantPool = ConstantPool(ref.constantPool) @@ -43,3 +43,5 @@ class JdiClass(ref: com.sun.jdi.ReferenceType) extends JdiType(ref) with ClassTy override def sourceName: Option[String] = Option(ref.sourceName) private def visibleMethods: Seq[JdiMethod] = ref.visibleMethods.asScala.map(JdiMethod(_)).toSeq + + override def toString: String = ref.toString diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClassLoader.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClassLoader.scala index 812f129..e3d1e0d 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClassLoader.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiClassLoader.scala @@ -7,3 +7,5 @@ import ch.epfl.scala.decoder.binary.BinaryClassLoader class JdiClassLoader(classLoader: com.sun.jdi.ClassLoaderReference) extends BinaryClassLoader: override def loadClass(name: String): JdiClass = JdiClass(classLoader.visibleClasses.asScala.find(_.name == name).get) + + override def toString = classLoader.toString diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala index d9b730c..ca667ed 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiField.scala @@ -4,10 +4,12 @@ import ch.epfl.scala.decoder.binary.* class JdiField(field: com.sun.jdi.Field) extends Field: - override def declaringClass: ClassType = JdiClass(field.declaringType()) + override def declaringClass: BinaryClass = JdiClass(field.declaringType()) override def isStatic: Boolean = field.isStatic() override def name: String = field.name override def sourceLines: Option[SourceLines] = None override def `type`: Type = JdiType(field.`type`) + + override def toString: String = field.toString diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiLocalVariable.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiLocalVariable.scala deleted file mode 100644 index 3f0f5b7..0000000 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiLocalVariable.scala +++ /dev/null @@ -1,8 +0,0 @@ -package ch.epfl.scala.decoder.jdi - -import ch.epfl.scala.decoder.binary.* - -class JdiLocalVariable(localVariable: com.sun.jdi.LocalVariable) extends Parameter: - override def name: String = localVariable.name - override def sourceLines: Option[SourceLines] = None - override def `type`: Type = JdiType(localVariable.`type`) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala index f2e3538..cd463e7 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiMethod.scala @@ -11,18 +11,16 @@ class JdiMethod(method: com.sun.jdi.Method) extends Method: override def declaringClass: JdiClass = JdiClass(method.declaringType) - override def allParameters: Seq[Parameter] = - method.arguments.asScala.toSeq.map(JdiLocalVariable.apply(_)) + override def parameters: Seq[Parameter] = + method.arguments.asScala.toSeq.map(JdiParameter.apply(_, method)) override def variables: Seq[Variable] = - method.variables().asScala.toSeq.map(JdiVariable.apply(_, method)) + method.variables.asScala.toSeq.map(JdiVariable.apply(_, method)) override def returnType: Option[Type] = try Some(JdiType(method.returnType)) catch case e: com.sun.jdi.ClassNotLoadedException => None - override def returnTypeName: String = method.returnTypeName - override def isBridge: Boolean = method.isBridge override def isStatic: Boolean = method.isStatic @@ -43,3 +41,5 @@ class JdiMethod(method: com.sun.jdi.Method) extends Method: private def signature: String = method.signature private def bytecodes: Array[Byte] = method.bytecodes + + override def toString: String = method.toString diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiParameter.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiParameter.scala new file mode 100644 index 0000000..56241c3 --- /dev/null +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiParameter.scala @@ -0,0 +1,11 @@ +package ch.epfl.scala.decoder.jdi + +import ch.epfl.scala.decoder.binary.* + +class JdiParameter(variable: com.sun.jdi.LocalVariable, declaringMethod: JdiMethod) + extends JdiVariable(variable, declaringMethod) + with Parameter + +object JdiParameter: + def apply(variable: com.sun.jdi.LocalVariable, method: com.sun.jdi.Method): JdiParameter = + new JdiParameter(variable, JdiMethod(method)) diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiType.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiType.scala index 846e1f7..a99bdef 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiType.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiType.scala @@ -4,3 +4,5 @@ import ch.epfl.scala.decoder.binary.* class JdiType(tpe: com.sun.jdi.Type) extends Type: override def name: String = tpe.name override def sourceLines: Option[SourceLines] = None + + override def toString: String = tpe.toString diff --git a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala index f986e1a..547d1a8 100644 --- a/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala +++ b/src/main/scala/ch/epfl/scala/decoder/jdi/JdiVariable.scala @@ -2,11 +2,15 @@ package ch.epfl.scala.decoder.jdi import ch.epfl.scala.decoder.binary.* -class JdiVariable(variable: com.sun.jdi.LocalVariable, method: com.sun.jdi.Method) extends Variable: - - override def declaringMethod: Method = - JdiMethod(method) +class JdiVariable(variable: com.sun.jdi.LocalVariable, val declaringMethod: JdiMethod) extends Variable: override def name: String = variable.name override def sourceLines: Option[SourceLines] = None override def `type`: Type = JdiType(variable.`type`) + + override def toString: String = variable.toString + +object JdiVariable: + def apply(variable: com.sun.jdi.LocalVariable, method: com.sun.jdi.Method): JdiVariable = + if variable.isArgument then new JdiParameter(variable, JdiMethod(method)) + else new JdiVariable(variable, JdiMethod(method)) diff --git a/src/main/scala/tastyquery/decoder/Substituters.scala b/src/main/scala/tastyquery/decoder/Substituters.scala index 3046eb3..b735747 100644 --- a/src/main/scala/tastyquery/decoder/Substituters.scala +++ b/src/main/scala/tastyquery/decoder/Substituters.scala @@ -6,9 +6,9 @@ import tastyquery.Types.* import tastyquery.decoder.TypeMaps.* object Substituters: - def substLocalTypeParams(tp: TermType, from: List[LocalTypeParamSymbol], to: List[TypeOrWildcard])(using + def substLocalTypeParams(tp: TypeMappable, from: List[LocalTypeParamSymbol], to: List[TypeOrWildcard])(using Context - ): TermType = + ): tp.ThisTypeMappableType = new SubstLocalTypeParamsMap(from, to).apply(tp) private final class SubstLocalTypeParamsMap(from: List[LocalTypeParamSymbol], to: List[TypeOrWildcard])(using Context) diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryClassDecoderTests.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryClassDecoderTests.scala new file mode 100644 index 0000000..2ed1169 --- /dev/null +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryClassDecoderTests.scala @@ -0,0 +1,138 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.testutils.* +import tastyquery.Exceptions.* + +import scala.util.Properties + +class Scala3NextBinaryClassDecoderTests extends Scala3LtsBinaryClassDecoderTests: + override val scalaVersion = ScalaVersion.`3.next` + +class Scala3LtsBinaryClassDecoderTests extends BinaryDecoderSuite: + val scalaVersion: ScalaVersion = ScalaVersion.`3.lts` + def isScala33 = scalaVersion.isScala33 + def isScala34 = scalaVersion.isScala34 + + test("local class, trait and object by parents") { + val source = + """|package example + |object Main : + | class A + | def main(args: Array[String]): Unit = + | trait D extends A + | class C extends D + | object F extends D + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeClass("example.Main$D$1", "Main.main.D") + decoder.assertDecodeClass("example.Main$C$1", "Main.main.C") + decoder.assertDecodeClass("example.Main$F$2$", "Main.main.F") + decoder.assertDecodeMethod( + "example.Main$", + "example.Main$F$2$ F$1(scala.runtime.LazyRef F$lzy1$2)", + "Main.main.F: F", + generated = true + ) + decoder.assertDecodeMethod( + "example.Main$", + "example.Main$F$2$ F$lzyINIT1$1(scala.runtime.LazyRef F$lzy1$1)", + "Main.main.F.: F" + ) + } + + test("find local classes") { + val source = + """|package example + |class A + |trait B + |object Main: + | def m() = + | class C extends A,B + | () + | class E : + | class F + | class G extends A + | def l () = + | class C extends A + | class G extends A + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeClass("example.Main$C$1", "Main.m.C") + decoder.assertDecodeClass("example.Main$E$1$F", "Main.m.E.F") + decoder.assertDecodeClass("example.Main$G$1", "Main.m.G") + } + + test("anonymous class") { + val source = + """|package example + |class B : + | def n = 42 + |class A : + | def m(t: => Any): Int = + | val b = new B { + | def m = () + | } + | b.n + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeClass("example.A$$anon$1", "A.m.") + } + + test("local enum") { + val source = + """|package example + |object Main : + | def m = + | enum A: + | case B + | () + |""".stripMargin + + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeClass("example.Main$A$1", "Main.m.A") + } + + test("tasty-query#412"): + val decoder = initDecoder("dev.zio", "zio-interop-cats_3", "23.1.0.0")(using ThrowOrWarn.ignore) + decoder.assertDecodeClass("zio.BuildInfoInteropCats", "BuildInfoInteropCats") + decoder.assertDecodeClass("zio.interop.ZioMonadErrorE$$anon$4", "ZioMonadErrorE.adaptError.") + + test("tasty-query#414"): + val decoder = initDecoder("io.github.dieproht", "matr-dflt-data_3", "0.0.3") + decoder.assertDecodeClass( + "matr.dflt.DefaultMatrixFactory$$anon$1", + "DefaultMatrixFactory.defaultMatrixFactory.builder." + ) + + test("tasty-query#415"): + val decoder = initDecoder("com.github.mkroli", "dns4s-fs2_3", "0.21.0") + decoder.assertDecodeClass("com.github.mkroli.dns4s.fs2.DnsClientOps", "DnsClientOps") + + test("tasty-query#423"): + val decoder = initDecoder("com.typesafe.akka", "akka-stream_3", "2.8.5") + decoder.assertDecodeClass("akka.stream.scaladsl.FlowOps$passedEnd$2$", "FlowOps.zipAllFlow.passedEnd") + + test("tasty-query#424"): + val decoder = initDecoder("edu.gemini", "lucuma-itc-core_3", "0.10.0") + decoder.assertDecodeClass("lucuma.itc.ItcImpl", "ItcImpl") + + test("specialized class"): + val decoder = initDecoder("org.scala-lang", "scala-library", "2.13.12") + decoder.assertDecodeClass("scala.runtime.java8.JFunction1$mcII$sp", "JFunction1$mcII$sp") + + test("local class in value class"): + val source = + """|package example + | + |class A(self: String) extends AnyVal: + | def m(size: Int): String = + | class B: + | def m(): String = + | self.take(size) + | val b = new B + | b.m() + |""".stripMargin + // tasty-query#428 + val decoder = TestingDecoder(source, scalaVersion)(using ThrowOrWarn.ignore) + decoder.assertDecodeClass("example.A$B$1", "A.m.B") diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala index 0771c4d..3500654 100644 --- a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderStats.scala @@ -13,9 +13,10 @@ class BinaryDecoderStats extends BinaryDecoderSuite: val decoder = initDecoder("org.scala-lang", "scala3-compiler_3", "3.3.1") decoder.assertDecodeAll( expectedClasses = ExpectedCount(4426), - expectedMethods = ExpectedCount(68421, ambiguous = 25, notFound = 33), - expectedFields = ExpectedCount(12550, ambiguous = 23, notFound = 3), - expectedVariables = ExpectedCount(129844, ambiguous = 4927, notFound = 2475) + expectedMethods = ExpectedCount(68422, ambiguous = 24, notFound = 33), + expectedFields = ExpectedCount(12552, ambiguous = 23, notFound = 1), + expectedVariables = ExpectedCount(143319, ambiguous = 1885, notFound = 213) + // classFilter = Set("dotty.tools.dotc.util.HashSet") ) test("scala3-compiler:3.0.2"): @@ -23,8 +24,8 @@ class BinaryDecoderStats extends BinaryDecoderSuite: decoder.assertDecodeAll( expectedClasses = ExpectedCount(3859, notFound = 3), expectedMethods = ExpectedCount(60762, ambiguous = 24, notFound = 163), - expectedFields = ExpectedCount(10674, ambiguous = 19, notFound = 6), - expectedVariables = ExpectedCount(112306, ambiguous = 4443, notFound = 2187) + expectedFields = ExpectedCount(10672, ambiguous = 21, notFound = 6), + expectedVariables = ExpectedCount(123372, ambiguous = 1599, notFound = 1041) ) test("io.github.vigoo:zio-aws-ec2_3:4.0.5 - slow".ignore): @@ -40,7 +41,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: expectedClasses = ExpectedCount(10), expectedMethods = ExpectedCount(218), expectedFields = ExpectedCount(45), - expectedVariables = ExpectedCount(194, ambiguous = 1, notFound = 45) + expectedVariables = ExpectedCount(236, notFound = 11) ) test("net.zygfryd:jackshaft_3:0.2.2".ignore): @@ -53,7 +54,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: expectedClasses = ExpectedCount(245), expectedMethods = ExpectedCount(2755, notFound = 92), expectedFields = ExpectedCount(298), - expectedVariables = ExpectedCount(4541, ambiguous = 58, notFound = 38) + expectedVariables = ExpectedCount(4873, notFound = 8) ) test("org.clulab:processors-main_3:8.5.3".ignore): @@ -71,7 +72,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(27), ExpectedCount(174, notFound = 2), expectedFields = ExpectedCount(20, ambiguous = 4), - expectedVariables = ExpectedCount(253, ambiguous = 3, notFound = 6) + expectedVariables = ExpectedCount(299, notFound = 1) ) test("com.zengularity:benji-google_3:2.2.1".ignore): @@ -102,7 +103,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(149, notFound = 9), ExpectedCount(3546, notFound = 59), expectedFields = ExpectedCount(144, notFound = 2), - expectedVariables = ExpectedCount(14750, ambiguous = 275, notFound = 39) + expectedVariables = ExpectedCount(15225, ambiguous = 7, notFound = 10) ) test("com.evolution:scache_3:5.1.2"): @@ -114,7 +115,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(105), ExpectedCount(1509), expectedFields = ExpectedCount(161), - expectedVariables = ExpectedCount(3150, ambiguous = 51, notFound = 9) + expectedVariables = ExpectedCount(3354, ambiguous = 22, notFound = 3) ) test("com.github.j5ik2o:docker-controller-scala-dynamodb-local_:1.15.34"): @@ -126,7 +127,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(2), ExpectedCount(37), expectedFields = ExpectedCount(5), - expectedVariables = ExpectedCount(30) + expectedVariables = ExpectedCount(39) ) test("eu.ostrzyciel.jelly:jelly-grpc_3:0.5.3"): @@ -136,7 +137,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(24), ExpectedCount(353), expectedFields = ExpectedCount(61), - expectedVariables = ExpectedCount(443, ambiguous = 3, notFound = 2) + expectedVariables = ExpectedCount(480, notFound = 2) ) test("com.devsisters:zio-agones_3:0.1.0"): @@ -147,13 +148,13 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(83, notFound = 26), ExpectedCount(2804, ambiguous = 2, notFound = 5), expectedFields = ExpectedCount(258), - expectedVariables = ExpectedCount(3706, ambiguous = 17, notFound = 1, throwables = 48) + expectedVariables = ExpectedCount(3936, ambiguous = 2, notFound = 1) ) test("org.log4s:log4s_3:1.10.0".ignore): val fetchOptions = FetchOptions(keepProvided = true) val decoder = initDecoder("org.log4s", "log4s_3", "1.10.0", fetchOptions) - decoder.assertDecode("org.log4s.Warn", "java.lang.String name()", "") + decoder.assertDecodeMethod("org.log4s.Warn", "java.lang.String name()", "") test("org.virtuslab.scala-cli:cli2_3:0.1.5".ignore): val fetchOptions = FetchOptions(keepProvided = true) @@ -175,7 +176,7 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(19), ExpectedCount(158), expectedFields = ExpectedCount(32, ambiguous = 4, notFound = 2), - expectedVariables = ExpectedCount(204, notFound = 2) + expectedVariables = ExpectedCount(240, notFound = 1) ) test("io.github.valdemargr:gql-core_3:0.3.3"): @@ -184,5 +185,5 @@ class BinaryDecoderStats extends BinaryDecoderSuite: ExpectedCount(531), ExpectedCount(7267, ambiguous = 4, notFound = 1), expectedFields = ExpectedCount(851, notFound = 2), - expectedVariables = ExpectedCount(14771, ambiguous = 313, notFound = 26) + expectedVariables = ExpectedCount(15919, ambiguous = 14, notFound = 13) ) diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryFieldDecoderTests.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryFieldDecoderTests.scala new file mode 100644 index 0000000..0b8e9b4 --- /dev/null +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryFieldDecoderTests.scala @@ -0,0 +1,404 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.testutils.* +import tastyquery.Exceptions.* + +import scala.util.Properties + +class Scala3NextBinaryFieldDecoderTests extends Scala3LtsBinaryFieldDecoderTests: + override val scalaVersion: ScalaVersion = ScalaVersion.`3.next` + +class Scala3LtsBinaryFieldDecoderTests extends BinaryDecoderSuite: + val scalaVersion: ScalaVersion = ScalaVersion.`3.lts` + + def isScala33 = scalaVersion.isScala33 + def isScala34 = scalaVersion.isScala34 + + test("public and private fields") { + val source = + """|package example + | + |class A { + | var x: Int = 1 + | var `val`: Int = 1 + | private val y: String = "y" + | lazy val z: Int = 2 + | + | def foo: String = y + |} + | + |object A { + | val z: Int = 2 + | private var w: String = "w" + | private lazy val v: Int = 3 + | + | def bar: String = w + v + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "int x", "A.x: Int") + decoder.assertDecodeField("example.A", "int val", "A.val: Int") + decoder.assertDecodeField("example.A", "java.lang.String y", "A.y: String") + decoder.assertDecodeField("example.A", "java.lang.Object z$lzy1", "A.z: Int") + decoder.assertDecodeField("example.A$", "int z", "A.z: Int") + decoder.assertDecodeField("example.A$", "java.lang.String w", "A.w: String") + decoder.assertDecodeField("example.A$", "java.lang.Object v$lzy1", "A.v: Int") + } + + test("public and private objects") { + val source = + """|package example + | + |class A { + | object B + | private object C + |} + | + |object A { + | object D + | private object E + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "java.lang.Object B$lzy1", "A.B: B") + decoder.assertDecodeField("example.A", "java.lang.Object C$lzy1", "A.C: C") + decoder.assertDecodeField("example.A$", "example.A$ MODULE$", "A: A") + decoder.assertDecodeField("example.A", "long OFFSET$1", "A.: Long") + decoder.assertDecodeField("example.A", "long OFFSET$0", "A.: Long") + decoder.assertDecodeField("example.A$", "example.A$D$ D", "A.D: D") + decoder.assertDecodeField("example.A$D$", "example.A$D$ MODULE$", "A.D: D") + decoder.assertDecodeField("example.A$", "example.A$E$ E", "A.E: E") + decoder.assertDecodeField("example.A$E$", "example.A$E$ MODULE$", "A.E: E") + } + + test("fields in extended trait") { + val source = + """|package example + | + |trait A { + | private val x: Int = 1 + | private val y: Int = 2 + | val z: Int = 3 + |} + | + |class B extends A { + | val y: Int = 2 + |} + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.B", "int example$A$$x", "B.x: Int") + decoder.assertDecodeField("example.B", "int z", "B.z: Int") + // TODO fix + // decoder.assertDecodeField("example.B", "int y", "B.y: Int") + // decoder.assertDecodeField("example.B", "int example$A$$y", "B.y: Int") + } + + test("given field in extended trait") { + val source = + """|package example + | + |trait A: + | given x: Int = 1 + | + |class C extends A + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.C", "java.lang.Object x$lzy1", "C.x: Int") + } + + test("expanded names") { + val source = + """|package example + | + |trait A { + | def foo = + | enum B: + | case C, D + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField( + "example.A$B$3$", + "scala.runtime.LazyRef example$A$B$3$$$B$lzy1$3", + "A.foo.B.B.: B" + ) + } + + test("static fields in Java") { + val source = + """|package example; + | + |final class A { + | public static final int x = 1; + |} + |""".stripMargin + val javaModule = Module.fromJavaSource(source, scalaVersion) + val decoder = TestingDecoder(javaModule.mainEntry, javaModule.classpath) + decoder.assertDecodeField("example.A", "int x", "A.x: Int") + } + + test("case field in JavaLangEnum") { + val source = + """|package example + | + |enum A extends java.lang.Enum[A] : + | case B + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "example.A B", "A.B: A") + } + + test("anonymous using parameter") { + val source = + """|package example + | + |trait C + | + |class B (using C): + | def foo = summon[C] + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.B", "example.C x$1", "B.x$1: C") + } + + test("lazy val bitmap") { + val source = + """|package example + |import scala.annotation.threadUnsafe + | + |class A: + | @threadUnsafe lazy val x: Int = 1 + | @threadUnsafe lazy val y: Int = 1 + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "boolean xbitmap$1", "A.x.: Boolean") + decoder.assertDecodeField("example.A", "boolean ybitmap$1", "A.y.: Boolean") + } + + test("serialVersionUID fields") { + val source = + """|package example + | + |@SerialVersionUID(1L) + |class A + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "long serialVersionUID", "A.: Long") + } + + test("offset_m field") { + val source = + """|package example + | + |trait A { + | def foo: Int + |} + |class C: + | object B extends A { + | lazy val foo: Int = 42 + | } + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.C$B$", "long OFFSET$_m_0", "C.B.: Long") + } + + test("ambiguous module val and implicit def fields") { + val source = + """|package example + | + |object A { + | object B + | + | implicit class B (val x: Int) + | + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A$", "example.A$B$ B", "A.B: B") + } + + test("anon lazy val") { + val source = + """|package example + | + |class A: + | lazy val (a, b) = (1, 2) + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "java.lang.Object $1$$lzy1", "A.: (Int, Int)") + } + + test("outer field and param") { + val source = + """|package example + | + |class A[T](x: T){ + | class B { + | def foo: T = x + | } + | + | def bar: T = { + | class C { + | def foo: T = x + | } + | (new C).foo + | } + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A", "java.lang.Object example$A$$x", "A.x: T") + decoder.assertDecodeField("example.A$B", "example.A $outer", "A.B.: A[T]") + decoder.assertDecodeVariable("example.A$B", "void (example.A $outer)", "example.A $outer", 5, ": A[T]") + decoder.assertDecodeField("example.A$C$1", "example.A $outer", "A.bar.C.: A[T]") + decoder.assertDecodeVariable( + "example.A$C$1", + "void (example.A $outer)", + "example.A $outer", + 10, + ": A[T]" + ) + } + + test("intricated outer fields") { + val source = + """|package example + | + |trait A { + | class X + |} + | + |trait B extends A { + | class Y { + | class Z extends X + | } + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A$X", "example.A $outer", "A.X.: A") + decoder.assertDecodeField("example.B$Y$Z", "example.B$Y $outer", "B.Y.Z.: Y") + } + + test("indirect capture") { + val source = + """|package example + | + |class A(): + | def foo = + | val x: Int = 1 + | def met = x + | class B: + | def bar = met + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A$B$1", "int x$2", "A.foo.B.x.: Int") + decoder.assertDecodeVariable( + "example.A$B$1", + "void (int x$3, example.A $outer)", + "int x$3", + 8, + "x.: Int" + ) + } + + test("ambiguous indirect captures") { + val source = + """|package example + | + |class A(): + | def bar = + | val x: Int = 12 + | def getX = x + | def foo = + | val x: Int = 1 + | def met = x + | class B: + | def bar2 = met + | def bar3: Int = getX + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertAmbiguousField("example.A$B$1", "int x$3") + decoder.assertAmbiguousField("example.A$B$1", "int x$4") + } + + test("captured lazy ref") { + val source = + """|package example + |trait C + | + |class A { + | def foo = + | lazy val c: C = new C {} + | class B: + | def ct = c + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A$B$1", "scala.runtime.LazyRef c$lzy1$3", "A.foo.B.c.: C") + decoder.assertDecodeVariable( + "example.A$B$1", + "void (scala.runtime.LazyRef c$lzy1$4, example.A $outer)", + "scala.runtime.LazyRef c$lzy1$4", + 8, + "c.: C" + ) + } + + test("local class capture") { + val source = + """|package example + | + |class Foo { + | def foo = + | val x = " " + | class A: + | def bar = x + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.Foo$A$1", "java.lang.String x$1", "Foo.foo.A.x.: String") + decoder.assertDecodeVariable( + "example.Foo$A$1", + "void (java.lang.String x$2)", + "java.lang.String x$2", + 7, + "x.: String" + ) + } + + test("captured value class") { + val source = + """|package example + | + |class A(val x: Int) extends AnyVal: + | def foo = + | class B: + | def bar = x + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeField("example.A$B$1", "int $this$1", "A.foo.B.x.: Int") + } + + test("captured through inline method") { + val source = + """|package example + | + |trait C + | + |object A: + | inline def withMode(inline op: C ?=> Unit)(using C): Unit = op + | + | def foo(using C) = withMode { + | class B: + | def bar = summon[C] + | } + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertNotFoundField("example.A$B$1", "example.C x$1$1") + } diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryMethodDecoderTests.scala similarity index 51% rename from src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala rename to src/test/scala/ch/epfl/scala/decoder/BinaryMethodDecoderTests.scala index 05c465f..bf065bc 100644 --- a/src/test/scala/ch/epfl/scala/decoder/BinaryDecoderTests.scala +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryMethodDecoderTests.scala @@ -5,838 +5,14 @@ import tastyquery.Exceptions.* import scala.util.Properties -class Scala3LtsBinaryDecoderTests extends BinaryDecoderTests(ScalaVersion.`3.lts`) -class Scala3NextBinaryDecoderTests extends BinaryDecoderTests(ScalaVersion.`3.next`) +class Scala3NextBinaryMethodDecoderTests extends Scala3LtsBinaryMethodDecoderTests: + override val scalaVersion: ScalaVersion = ScalaVersion.`3.next` -abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDecoderSuite: +class Scala3LtsBinaryMethodDecoderTests extends BinaryDecoderSuite: + val scalaVersion: ScalaVersion = ScalaVersion.`3.lts` def isScala33 = scalaVersion.isScala33 def isScala34 = scalaVersion.isScala34 - test("scala3-compiler:3.3.1"): - val decoder = initDecoder("org.scala-lang", "scala3-compiler_3", "3.3.1") - decoder.assertDecodeVariable( - "scala.quoted.runtime.impl.QuoteMatcher$", - "scala.Option treeMatch(dotty.tools.dotc.ast.Trees$Tree scrutineeTree, dotty.tools.dotc.ast.Trees$Tree patternTree, dotty.tools.dotc.core.Contexts$Context x$3)", - "scala.util.boundary$Break ex", - "ex: Break[T]", - 128 - ) - - test("tailLocal variables") { - val source = - """|package example - | - |class A { - | @annotation.tailrec - | private def factAcc(x: Int, acc: Int): Int = - | if x <= 1 then List(1, 2).map(_ * acc).sum - | else factAcc(x - 1, x * acc) - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)") - // decoder.assertDecodeVariable("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)", "int acc$tailLocal1$1", "acc.: Int", 6) - } - - test("SAMOrPartialFunctionImpl") { - val source = - """|package example - | - |class A: - | def foo(x: Int) = - | val xs = List(x, x + 1, x + 2) - | xs.collect { case z if z % 2 == 0 => z } - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A$$anon$1", "boolean isDefinedAt(int x)") - decoder.assertDecodeVariable( - "example.A$$anon$1", - "boolean isDefinedAt(int x)", - "int x", - "x: A", - 6 - ) - decoder.assertDecodeVariable( - "example.A$$anon$1", - "boolean isDefinedAt(int x)", - "int z", - "z: Int", - 6 - ) - - decoder.showVariables("example.A$$anon$1", "java.lang.Object applyOrElse(int x, scala.Function1 default)") - decoder.assertDecodeVariable( - "example.A$$anon$1", - "java.lang.Object applyOrElse(int x, scala.Function1 default)", - "scala.Function1 default", - "default: A1 => B1", - 6 - ) - } - - test("inlined this") { - val source = - """|package example - | - |class A(x: Int): - | inline def foo: Int = x + x - | - |class B: - | def bar(a: A) = a.foo - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.B", "int bar(example.A a)") - decoder.assertDecodeVariable( - "example.B", - "int bar(example.A a)", - "example.A A_this", - "this: A.this.type", - 7, - generated = true - ) - } - - test("inlined param") { - val source = - """|package example - | - |class A { - | inline def foo(x: Int): Int = x + x - | - | def bar(y: Int) = foo(y + 2) - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int bar(int y)") - decoder.assertDecodeVariable("example.A", "int bar(int y)", "int x$proxy1", "x: Int", 4) - } - - test("bridge parameter") { - val source = - """|package example - |class A - |class B extends A - | - |class C: - | def foo(x: Int): A = new A - | - |class D extends C: - | override def foo(y: Int): B = new B - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.D", "example.A foo(int x)") - decoder.assertIgnoredVariable("example.D", "example.A foo(int x)", "int x", "Bridge") - } - - test("lazy val capture") { - val source = - """|package example - | - |class A { - | def foo = - | val y = 4 - | lazy val z = y + 1 - | def bar = z - | z - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int bar$1(scala.runtime.LazyInt z$lzy1$3, int y$3)") - decoder.assertDecodeVariable( - "example.A", - "int bar$1(scala.runtime.LazyInt z$lzy1$3, int y$3)", - "scala.runtime.LazyInt z$lzy1$3", - "z.: Int", - 7, - generated = true - ) - - decoder.showVariables("example.A", "int z$lzyINIT1$1(scala.runtime.LazyInt z$lzy1$1, int y$1)") - decoder.assertDecodeVariable( - "example.A", - "int z$lzyINIT1$1(scala.runtime.LazyInt z$lzy1$1, int y$1)", - "scala.runtime.LazyInt z$lzy1$1", - "z.: Int", - 7, - generated = true - ) - - decoder.showVariables("example.A", "int z$1(scala.runtime.LazyInt z$lzy1$2, int y$2)") - decoder.assertDecodeVariable( - "example.A", - "int z$1(scala.runtime.LazyInt z$lzy1$2, int y$2)", - "scala.runtime.LazyInt z$lzy1$2", - "z.: Int", - 7, - generated = true - ) - } - - test("by-name arg capture") { - val source = - """|package example - | - |class A { - | def foo(x: => Int) = ??? - | - | def bar(x: Int) = - | foo(x) - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int bar$$anonfun$1(int x$1)") - decoder.assertDecodeVariable( - "example.A", - "int bar$$anonfun$1(int x$1)", - "int x$1", - "x.: Int", - 7, - generated = true - ) - } - - test("binds") { - val source = - """|package example - | - |class B - |case class C(x: Int, y: String) extends B - |case class D(z: String) extends B - |case class E(v: Int) extends B - |case class F(w: Int) extends B - | - |class A: - | private def bar(a: B) = - | a match - | case F(w) => w - | case C(x, y) => - | x - | case D(z) => 0 - | case E(v) => 1 - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - // decoder.showVariables("example.A", "int bar(example.B a)") - // decoder.assertDecodeAll( - // ExpectedCount(2), - // ExpectedCount(37), - // expectedFields = ExpectedCount(5) - // ) - decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int w", "w: Int", 12) - decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int x", "x: Int", 14) - - } - - test("mixin and trait static forwarders") { - val source = - """|package example - | - |trait A { - | def foo(x: Int): Int = x - |} - | - |class B extends A - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.B", "int foo(int x)") - decoder.showVariables("example.A", "int foo$(example.A $this, int x)") - decoder.assertDecodeVariable("example.B", "int foo(int x)", "int x", "x: Int", 7) - decoder.assertDecodeVariable( - "example.A", - "int foo$(example.A $this, int x)", - "example.A $this", - "this: A.this.type", - 4, - generated = true - ) - } - - test("this AnyVal") { - val source = - """|package example - | - |class A(x: Int) extends AnyVal { - | def foo: Int = - | x - | - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A$", "int foo$extension(int $this)") - decoder.assertDecodeVariable( - "example.A$", - "int foo$extension(int $this)", - "int $this", - "x: Int", - 5, - generated = true - ) - } - - test("this variable") { - val source = - """|package example - | - |class A: - | def foo: Int = - | 4 - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int foo()") - decoder.assertDecodeVariable("example.A", "int foo()", "example.A this", "this: A.this.type", 5, generated = true) - } - - test("binds tuple and pattern matching") { - val source = - """|package example - | - |class A { - | def foo: Int = - | val x = (1, 2) - | val (c, d) = (3, 4) - | x match - | case (a, b) => a + b - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int foo()") - decoder.assertDecodeVariable("example.A", "int foo()", "scala.Tuple2 x", "x: (Int, Int)", 5) - decoder.assertDecodeVariable("example.A", "int foo()", "int c", "c: Int", 6) - decoder.assertDecodeVariable("example.A", "int foo()", "int d", "d: Int", 6) - decoder.assertDecodeVariable("example.A", "int foo()", "int a", "a: Int", 8) - decoder.assertDecodeVariable("example.A", "int foo()", "int b", "b: Int", 8) - } - - test("ambiguous impossible") { - val source = - """|package example - | - |class A: - | def foo(a: Boolean) = - | if (a) {val x = 1} else {val x = 2} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "void foo(boolean a)") - decoder.assertAmbiguousVariable("example.A", "void foo(boolean a)", "int x", 5) - } - - test("ambiguous variables 2") { - val source = - """|package example - | - |class A: - | def foo() = - | var i = 0 - | while i < 10 do - | val x = i - | i += 1 - | val x = 17 - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "void foo()") - decoder.assertDecodeVariable("example.A", "void foo()", "int x", "x: Int", line = 7) - decoder.assertDecodeVariable("example.A", "void foo()", "int x", "x: Int", line = 9) - } - - test("ambiguous variables") { - val source = - """|package example - | - |class A : - | def foo(a: Boolean) = - | if a then - | val x = 1 - | x - | else - | val x = "2" - | x - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "java.lang.Object foo(boolean a)") - decoder.assertDecodeVariable("example.A", "java.lang.Object foo(boolean a)", "int x", "x: Int", line = 7) - decoder.assertDecodeVariable( - "example.A", - "java.lang.Object foo(boolean a)", - "java.lang.String x", - "x: String", - line = 9 - ) - } - - test("local object") { - val source = - """|package example - | - |class A { - | def foo() = - | object B - | B - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "java.lang.Object foo()") - decoder.assertNoSuchElementVariable("example.A", "java.lang.Object foo()", "example.A$B$2$ B$1") - } - - test("local lazy val") { - val source = - """|package example - | - |class A: - | def foo() = - | lazy val x: Int = 1 - | x - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int foo()") - decoder.assertNoSuchElementVariable("example.A", "int foo()", "int x$1$lzyVal") - } - - test("array") { - val source = - """|package example - | - |class A { - | def foo() = - | val x = Array(1, 2, 3) - | x - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int[] foo()") - decoder.assertDecodeVariable("example.A", "int[] foo()", "int[] x", "x: Array[Int]", 6) - } - - test("captured param in a local def") { - val source = - """|package example - | - |class A { - | def foo(x: Int) = { - | def bar() = x - | } - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int bar$1(int x$1)") - decoder.assertDecodeVariable( - "example.A", - "int bar$1(int x$1)", - "int x$1", - "x.: Int", - 5, - generated = true - ) - } - - test("method parameter") { - val source = - """|package example - | - |class A: - | def foo(y: String) = - | println(y) - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "void foo(java.lang.String y)") - decoder.assertDecodeVariable( - "example.A", - "void foo(java.lang.String y)", - "java.lang.String y", - "y: String", - 5 - ) - } - - test("local variable") { - val source = - """|package example - | - |class A: - | def foo = - | val x: Int = 1 - | x - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.showVariables("example.A", "int foo()") - decoder.assertDecodeVariable("example.A", "int foo()", "int x", "x: Int", 6) - } - - test("capture value class") { - val source = - """|package example - | - |class A(val x: Int) extends AnyVal: - | def foo = - | class B: - | def bar = x - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A$B$1", "int $this$1", "A.foo.B.x.: Int", generated = true) - } - - test("capture inline method") { - val source = - """|package example - | - |trait C - | - |object A: - | inline def withMode(inline op: C ?=> Unit)(using C): Unit = op - | - | def foo(using C) = withMode { - | class B: - | def bar = summon[C] - | } - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertNotFoundField("example.A$B$1", "example.C x$1$1") - } - - test("anon lazy val") { - val source = - """|package example - | - |class A: - | lazy val (a, b) = (1, 2) - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "java.lang.Object $1$$lzy1", "A.: (Int, Int)") - } - - test("expanded names fields") { - val source = - """|package example - | - |trait A { - | def foo = - | enum B: - | case C, D - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField( - "example.A$B$3$", - "scala.runtime.LazyRef example$A$B$3$$$B$lzy1$3", - "A.foo.B.B.: B", - generated = true - ) - } - - test("lazy ref") { - val source = - """|package example - |trait C - | - |class A { - | def foo = - | lazy val c: C = new C {} - | class B: - | def ct = c - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField( - "example.A$B$1", - "scala.runtime.LazyRef c$lzy1$3", - "A.foo.B.c.: C", - generated = true - ) - } - - test("ambiguous indirect captures") { - val source = - """|package example - | - |class A(): - | def bar = - | val x: Int = 12 - | def getX = x - | def foo = - | val x: Int = 1 - | def met = x - | class B: - | def bar2 = met - | def bar3: Int = getX - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertAmbiguousField("example.A$B$1", "int x$3") - decoder.assertAmbiguousField("example.A$B$1", "int x$4") - } - - test("indirect capture") { - val source = - """|package example - | - |class A(): - | def foo = - | val x: Int = 1 - | def met = x - | class B: - | def bar = met - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A$B$1", "int x$2", "A.foo.B.x.: Int", generated = true) - } - - test("anonymous using parameter") { - val source = - """|package example - | - |trait C - | - |class B (using C): - | def foo = summon[C] - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.B", "example.C x$1", "B.x$1: C") - - } - - test("lazy val bitmap") { - val source = - """|package example - |import scala.annotation.threadUnsafe - | - |class A: - | @threadUnsafe lazy val x: Int = 1 - | @threadUnsafe lazy val y: Int = 1 - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "boolean xbitmap$1", "A.x.: Boolean", generated = true) - decoder.assertDecodeField("example.A", "boolean ybitmap$1", "A.y.: Boolean", generated = true) - } - - test("class defined in a method fields") { - val source = - """|package example - | - |class Foo { - | def foo = - | val x = " " - | class A: - | def bar = x - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField( - "example.Foo$A$1", - "java.lang.String x$1", - "Foo.foo.A.x.: String", - generated = true - ) - } - - test("case field in JavaLangEnum") { - val source = - """|package example - | - |enum A extends java.lang.Enum[A] : - | case B - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "example.A B", "A.B: A") - } - - test("serialVersionUID fields") { - val source = - """|package example - | - |@SerialVersionUID(1L) - |class A - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "long serialVersionUID", "A.: Long", generated = true) - } - - test("static fields in static classes Java") { - val source = - """|package example; - | - |final class A { - | public static final int x = 1; - |} - |""".stripMargin - val javaModule = Module.fromJavaSource(source, scalaVersion) - val decoder = TestingDecoder(javaModule.mainEntry, javaModule.classpath) - decoder.assertDecodeField("example.A", "int x", "A.x: Int") - } - - test("extend trait with given fields") { - val source = - """|package example - | - |trait A: - | given x: Int = 1 - | - |class C extends A - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.C", "java.lang.Object x$lzy1", "C.x: Int") - } - - test("extend traits with val fields") { - val source = - """|package example - | - |trait A { - | private val x: Int = 1 - | private val y: Int = 2 - | val z: Int = 3 - |} - | - |class B extends A { - | val y: Int = 2 - |} - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.B", "int example$A$$x", "B.x: Int") - decoder.assertDecodeField("example.B", "int z", "B.z: Int") - // TODO fix - // decoder.assertDecodeField("example.B", "int y", "B.y: Int") - // decoder.assertDecodeField("example.B", "int example$A$$y", "B.y: Int") - } - - test("notFound offset_m field") { - val source = - """|package example - | - |trait A { - | def foo: Int - |} - |class C: - | object B extends A { - | lazy val foo: Int = 42 - | } - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.C$B$", "long OFFSET$_m_0", "C.B.: Long", generated = true) - } - - test("ambiguous Object/ImplicitClass fields") { - val source = - """|package example - | - |object A { - | object B - | - | implicit class B (val x: Int) - | - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A$", "example.A$B$ B", "A.B: B") - } - - test("public and private fields") { - val source = - """|package example - | - |class A { - | var x: Int = 1 - | var `val`: Int = 1 - | private val y: String = "y" - | lazy val z: Int = 2 - | - | def foo: String = y - |} - | - |object A { - | val z: Int = 2 - | private var w: String = "w" - | private lazy val v: Int = 3 - | - | def bar: String = w + v - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "int x", "A.x: Int") - decoder.assertDecodeField("example.A", "int val", "A.val: Int") - decoder.assertDecodeField("example.A", "java.lang.String y", "A.y: String") - decoder.assertDecodeField("example.A", "java.lang.Object z$lzy1", "A.z: Int") - decoder.assertDecodeField("example.A$", "int z", "A.z: Int") - decoder.assertDecodeField("example.A$", "java.lang.String w", "A.w: String") - decoder.assertDecodeField("example.A$", "java.lang.Object v$lzy1", "A.v: Int") - } - - test("public and private objects") { - val source = - """|package example - | - |class A { - | object B - | private object C - |} - | - |object A { - | object D - | private object E - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "java.lang.Object B$lzy1", "A.B: B") - decoder.assertDecodeField("example.A", "java.lang.Object C$lzy1", "A.C: C") - decoder.assertDecodeField("example.A$", "example.A$ MODULE$", "A: A", true) - decoder.assertDecodeField("example.A", "long OFFSET$1", "A.: Long", true) - decoder.assertDecodeField("example.A", "long OFFSET$0", "A.: Long", true) - decoder.assertDecodeField("example.A$", "example.A$D$ D", "A.D: D") - decoder.assertDecodeField("example.A$D$", "example.A$D$ MODULE$", "A.D: D", true) - decoder.assertDecodeField("example.A$", "example.A$E$ E", "A.E: E") - decoder.assertDecodeField("example.A$E$", "example.A$E$ MODULE$", "A.E: E", true) - } - - test("outer field") { - val source = - """|package example - | - |class A[T](x: T){ - | class B { - | def foo: T = x - | } - | - | def bar: T = { - | class C { - | def foo: T = x - | } - | (new C).foo - | } - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A", "java.lang.Object example$A$$x", "A.x: T") - decoder.assertDecodeField("example.A$B", "example.A $outer", "A.B.: A[T]", generated = true) - decoder.assertDecodeField("example.A$C$1", "example.A $outer", "A.bar.C.: A[T]", generated = true) - } - - test("intricated outer fields") { - val source = - """|package example - | - |trait A { - | class X - |} - | - |trait B extends A { - | class Y { - | class Z extends X - | } - |} - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecodeField("example.A$X", "example.A $outer", "A.X.: A", generated = true) - decoder.assertDecodeField("example.B$Y$Z", "example.B$Y $outer", "B.Y.Z.: Y", generated = true) - } - test("mixin and static forwarders") { val source = """|package example @@ -878,49 +54,22 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco val javaSig = "java.lang.String m()" val staticTraitAccessor = "java.lang.String m$(example.A $this)" - decoder.assertDecode("example.A", javaSig, "A.m(): String") - decoder.assertDecode("example.A", staticTraitAccessor, "A.m.(): String", generated = true) - decoder.assertDecode("example.B", javaSig, "B.m.(): String", generated = true) - decoder.assertDecode("example.C", javaSig, "C.m.(): String", generated = true) - decoder.assertDecode("example.D", javaSig, "D.m(): String") - decoder.assertDecode("example.F$", javaSig, "F.m.(): String", generated = true) - decoder.assertDecode("example.F", javaSig, "F.m.(): String", generated = true) - decoder.assertDecode("example.Main$G", javaSig, "Main.G.m(): String") - decoder.assertDecode("example.Main$H", javaSig, "Main.H.m.(): String", generated = true) - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", javaSig, "A.m(): String") + decoder.assertDecodeMethod("example.A", staticTraitAccessor, "A.m.(): String", generated = true) + decoder.assertDecodeMethod("example.B", javaSig, "B.m.(): String", generated = true) + decoder.assertDecodeMethod("example.C", javaSig, "C.m.(): String", generated = true) + decoder.assertDecodeMethod("example.D", javaSig, "D.m(): String") + decoder.assertDecodeMethod("example.F$", javaSig, "F.m.(): String", generated = true) + decoder.assertDecodeMethod("example.F", javaSig, "F.m.(): String", generated = true) + decoder.assertDecodeMethod("example.Main$G", javaSig, "Main.G.m(): String") + decoder.assertDecodeMethod("example.Main$H", javaSig, "Main.H.m.(): String", generated = true) + decoder.assertDecodeMethod( "example.Main$$anon$1", javaSig, "Main.main..m.(): String", generated = true ) - decoder.assertDecode("example.Main$$anon$2", javaSig, "Main.main..m(): String") - } - - test("local class, trait and object by parents") { - val source = - """|package example - |object Main : - | class A - | def main(args: Array[String]): Unit = - | trait D extends A - | class C extends D - | object F extends D - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$D$1", "Main.main.D") - decoder.assertDecode("example.Main$C$1", "Main.main.C") - decoder.assertDecode("example.Main$F$2$", "Main.main.F") - decoder.assertDecode( - "example.Main$", - "example.Main$F$2$ F$1(scala.runtime.LazyRef F$lzy1$2)", - "Main.main.F: F", - generated = true - ) - decoder.assertDecode( - "example.Main$", - "example.Main$F$2$ F$lzyINIT1$1(scala.runtime.LazyRef F$lzy1$1)", - "Main.main.F.: F" - ) + decoder.assertDecodeMethod("example.Main$$anon$2", javaSig, "Main.main..m(): String") } test("local class and local method in a local class") { @@ -938,7 +87,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "void example$Main$Bar$1$$_$A$3()", "Main.m.….m.A(): Unit") + decoder.assertDecodeMethod("example.Main$", "void example$Main$Bar$1$$_$A$3()", "Main.m.….m.A(): Unit") } test("local methods with same name") { @@ -969,13 +118,13 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "void m$1(java.lang.String y$1)", "A.m1.m: Unit") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "void m$1(java.lang.String y$1)", "A.m1.m: Unit") + decoder.assertDecodeMethod( "example.A", "void m$2(java.lang.String y$2, java.lang.String z)", "A.m1.m.m(z: String): Unit" ) - decoder.assertDecode("example.A", "void m$3(int i)", "A.m2.m(i: Int): Unit") + decoder.assertDecodeMethod("example.A", "void m$3(int i)", "A.m2.m(i: Int): Unit") } test("getters and setters") { @@ -1010,26 +159,26 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco def getter(field: String): String = s"java.lang.String $field()" def setter(field: String, param: String = "x$1"): String = s"void ${field}_$$eq(java.lang.String $param)" - decoder.assertDecode("example.Main$", getter("x1"), "Main.x1: String", generated = true) - decoder.assertDecode("example.Main$", getter("x2"), "Main.x2: String", generated = true) - decoder.assertDecode("example.Main$", setter("x2"), "Main.x2_=(String): Unit", generated = true) + decoder.assertDecodeMethod("example.Main$", getter("x1"), "Main.x1: String", generated = true) + decoder.assertDecodeMethod("example.Main$", getter("x2"), "Main.x2: String", generated = true) + decoder.assertDecodeMethod("example.Main$", setter("x2"), "Main.x2_=(String): Unit", generated = true) // static forwarders - decoder.assertDecode("example.Main", getter("x1"), "Main.x1.: String", generated = true) - decoder.assertDecode("example.Main", getter("x2"), "Main.x2.: String", generated = true) - decoder.assertDecode( + decoder.assertDecodeMethod("example.Main", getter("x1"), "Main.x1.: String", generated = true) + decoder.assertDecodeMethod("example.Main", getter("x2"), "Main.x2.: String", generated = true) + decoder.assertDecodeMethod( "example.Main", setter("x2", param = "arg0"), "Main.x2_=.(String): Unit", generated = true ) - decoder.assertDecode("example.A", getter("a1"), "A.a1: String", generated = true) - decoder.assertDecode("example.A", getter("a2"), "A.a2: String") - decoder.assertDecode("example.B", getter("b1"), "B.b1: String", generated = true) - decoder.assertDecode("example.B", getter("b2"), "B.b2: String", generated = true) - decoder.assertDecode("example.C", getter("c1"), "C.c1: String", generated = true) - decoder.assertDecode("example.D", getter("d1"), "D.d1: String", generated = true) + decoder.assertDecodeMethod("example.A", getter("a1"), "A.a1: String", generated = true) + decoder.assertDecodeMethod("example.A", getter("a2"), "A.a2: String") + decoder.assertDecodeMethod("example.B", getter("b1"), "B.b1: String", generated = true) + decoder.assertDecodeMethod("example.B", getter("b2"), "B.b2: String", generated = true) + decoder.assertDecodeMethod("example.C", getter("c1"), "C.c1: String", generated = true) + decoder.assertDecodeMethod("example.D", getter("d1"), "D.d1: String", generated = true) } test("bridges") { @@ -1048,9 +197,9 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco def javaSig(returnType: String): String = s"java.lang.Object m()" - decoder.assertDecode("example.A", "java.lang.Object m()", "A.m(): Object") - decoder.assertDecode("example.B", "java.lang.Object m()", "B.m.(): String", generated = true) - decoder.assertDecode("example.B", "java.lang.String m()", "B.m(): String") + decoder.assertDecodeMethod("example.A", "java.lang.Object m()", "A.m(): Object") + decoder.assertDecodeMethod("example.B", "java.lang.Object m()", "B.m.(): String", generated = true) + decoder.assertDecodeMethod("example.B", "java.lang.String m()", "B.m(): String") } test("outer accessors") { @@ -1064,7 +213,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A$B$C", "example.A$B example$A$B$C$$$outer()", "A.B.C.: B.this.type", @@ -1082,28 +231,28 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "int m1(int x, int y)", "Main.m1(using x: Int, y: Int): Int") - decoder.assertDecode("example.Main$", "int m2(int x)", "Main.m2(implicit x: Int): Int") - decoder.assertDecode( + decoder.assertDecodeMethod("example.Main$", "int m1(int x, int y)", "Main.m1(using x: Int, y: Int): Int") + decoder.assertDecodeMethod("example.Main$", "int m2(int x)", "Main.m2(implicit x: Int): Int") + decoder.assertDecodeMethod( "example.Main$", "void m3(java.lang.String x$1, int x$2)", "Main.m3(using String, Int): Unit" ) // static forwarders - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main", "int m1(int arg0, int arg1)", "Main.m1.(using x: Int, y: Int): Int", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main", "int m2(int arg0)", "Main.m2.(implicit x: Int): Int", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main", "void m3(java.lang.String arg0, int arg1)", "Main.m3.(using String, Int): Unit", @@ -1112,28 +261,6 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco } - test("find local classes") { - val source = - """|package example - |class A - |trait B - |object Main: - | def m() = - | class C extends A,B - | () - | class E : - | class F - | class G extends A - | def l () = - | class C extends A - | class G extends A - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$C$1", "Main.m.C") - decoder.assertDecode("example.Main$E$1$F", "Main.m.E.F") - decoder.assertDecode("example.Main$G$1", "Main.m.G") - } - test("local class in signature") { val source = """|package example @@ -1149,8 +276,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | t |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "example.Main$D$1 m$1(example.Main$D$1 t)", "Main.A.….m.m(t: D): D") - decoder.assertDecode("example.Main$A$B$1", "void ()", "Main.A.….B.(): Unit") + decoder.assertDecodeMethod("example.Main$", "example.Main$D$1 m$1(example.Main$D$1 t)", "Main.A.….m.m(t: D): D") + decoder.assertDecodeMethod("example.Main$A$B$1", "void ()", "Main.A.….B.(): Unit") } test("operator-like names") { @@ -1162,8 +289,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | class ++ |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.$plus$plus", "int $plus$plus$1()", "++.m.++: Int") - decoder.assertDecode("example.$plus$plus$$plus$plus$2", "++.m.++") + decoder.assertDecodeMethod("example.$plus$plus", "int $plus$plus$1()", "++.m.++: Int") + decoder.assertDecodeClass("example.$plus$plus$$plus$plus$2", "++.m.++") } test("extension method of value classes") { @@ -1177,14 +304,14 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A$", "java.lang.String m$extension(java.lang.String $this)", "A.m(): String") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A$", "java.lang.String m$extension(java.lang.String $this)", "A.m(): String") + decoder.assertDecodeMethod( "example.A", "java.lang.String m$extension(java.lang.String arg0)", "A.m.(): String", generated = true ) - decoder.assertDecode("example.A", "void (java.lang.String x)", "A.(x: String): Unit") + decoder.assertDecodeMethod("example.A", "void (java.lang.String x)", "A.(x: String): Unit") } test("local method inside a value class") { @@ -1208,8 +335,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A$", "java.lang.String m$2(java.lang.String t)", "A.m.m(t: String): String") - decoder.assertDecode("example.A$", "java.lang.String m$1()", "A.m.m: String") + decoder.assertDecodeMethod("example.A$", "java.lang.String m$2(java.lang.String t)", "A.m.m(t: String): String") + decoder.assertDecodeMethod("example.A$", "java.lang.String m$1()", "A.m.m: String") } test("multi parameter lists") { @@ -1225,7 +352,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |class A |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "java.lang.String m(example.A a)", "Main.m()(a: A): String") + decoder.assertDecodeMethod("example.Main$", "java.lang.String m(example.A a)", "Main.m()(a: A): String") } test("lazy initializer") { @@ -1247,10 +374,10 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A$", "java.lang.String a()", "A.a: String", generated = true) - decoder.assertDecode("example.A$", "java.lang.String b()", "A.b: String", generated = true) - decoder.assertDecode("example.B", "java.lang.String b()", "B.b: String") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A$", "java.lang.String a()", "A.a: String", generated = true) + decoder.assertDecodeMethod("example.A$", "java.lang.String b()", "A.b: String", generated = true) + decoder.assertDecodeMethod("example.B", "java.lang.String b()", "B.b: String") + decoder.assertDecodeMethod( "example.B", "java.lang.String b$(example.B $this)", "B.b.: String", @@ -1258,12 +385,17 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco ) // new in Scala 3.3.0 - decoder.assertDecode("example.A$", "java.lang.Object a$lzyINIT1()", "A.a.: String") - decoder.assertDecode("example.A$", "java.lang.Object b$lzyINIT1()", "A.b.: String", generated = true) + decoder.assertDecodeMethod("example.A$", "java.lang.Object a$lzyINIT1()", "A.a.: String") + decoder.assertDecodeMethod( + "example.A$", + "java.lang.Object b$lzyINIT1()", + "A.b.: String", + generated = true + ) // static forwarders - decoder.assertDecode("example.A", "java.lang.String a()", "A.a.: String", generated = true) - decoder.assertDecode("example.A", "java.lang.String b()", "A.b.: String", generated = true) + decoder.assertDecodeMethod("example.A", "java.lang.String a()", "A.a.: String", generated = true) + decoder.assertDecodeMethod("example.A", "java.lang.String b()", "A.b.: String", generated = true) } test("synthetic methods of case class") { @@ -1274,32 +406,47 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.String toString()", "A.toString(): String", generated = true) - decoder.assertDecode("example.A", "example.A copy(java.lang.String a)", "A.copy(a: String): A", generated = true) - decoder.assertDecode("example.A", "int hashCode()", "A.hashCode(): Int", generated = true) - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "java.lang.String toString()", "A.toString(): String", generated = true) + decoder.assertDecodeMethod( + "example.A", + "example.A copy(java.lang.String a)", + "A.copy(a: String): A", + generated = true + ) + decoder.assertDecodeMethod("example.A", "int hashCode()", "A.hashCode(): Int", generated = true) + decoder.assertDecodeMethod( "example.A", "boolean equals(java.lang.Object x$0)", "A.equals(Any): Boolean", generated = true ) - decoder.assertDecode("example.A", "int productArity()", "A.productArity: Int", generated = true) - decoder.assertDecode("example.A", "java.lang.String productPrefix()", "A.productPrefix: String", generated = true) - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "int productArity()", "A.productArity: Int", generated = true) + decoder.assertDecodeMethod( + "example.A", + "java.lang.String productPrefix()", + "A.productPrefix: String", + generated = true + ) + decoder.assertDecodeMethod( "example.A", "java.lang.Object productElement(int n)", "A.productElement(n: Int): Any", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "scala.collection.Iterator productIterator()", "A.productIterator.: Iterator[Any]", generated = true ) - decoder.assertDecode("example.A$", "example.A apply(java.lang.String a)", "A.apply(a: String): A", generated = true) - decoder.assertDecode("example.A$", "example.A unapply(example.A x$1)", "A.unapply(A): A", generated = true) + decoder.assertDecodeMethod( + "example.A$", + "example.A apply(java.lang.String a)", + "A.apply(a: String): A", + generated = true + ) + decoder.assertDecodeMethod("example.A$", "example.A unapply(example.A x$1)", "A.unapply(A): A", generated = true) } test("anonymous functions") { @@ -1314,36 +461,19 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | List("").map(x => x + 1) |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String m$$anonfun$1(boolean x)", "A.B.m.(x: Boolean): String" ) - decoder.assertDecode("example.A", "java.lang.String $anonfun$1(int x)", "A.B.m.(x: Int): String") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "java.lang.String $anonfun$1(int x)", "A.B.m.(x: Int): String") + decoder.assertDecodeMethod( "example.A", "java.lang.String m$$anonfun$2(java.lang.String x)", "A.m.(x: String): String" ) } - test("anonymous class") { - val source = - """|package example - |class B : - | def n = 42 - |class A : - | def m(t: => Any): Int = - | val b = new B { - | def m = () - | } - | b.n - | - |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A$$anon$1", "A.m.") - } - test("this.type") { val source = """|package example @@ -1354,7 +484,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "example.A m()", "A.m(): A.this.type") + decoder.assertDecodeMethod("example.A", "example.A m()", "A.m(): A.this.type") } test("inline def with anonymous class and method") { @@ -1375,8 +505,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B", "int $anonfun$1(int x)", "example.m.(x: Int): Int") - decoder.assertDecode("example.B$$anon$1", "example.m.") + decoder.assertDecodeMethod("example.B", "int $anonfun$1(int x)", "example.m.(x: Int): Int") + decoder.assertDecodeClass("example.B$$anon$1", "example.m.") } test("SAM and partial functions") { @@ -1394,39 +524,39 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$$anon$1", "Main.") - decoder.assertDecode( + decoder.assertDecodeClass("example.Main$$anon$1", "Main.") + decoder.assertDecodeMethod( "example.Main$$anon$1", "int compare(java.lang.String x, java.lang.String y)", "Main..compare(x: String, y: String): Int" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$$anon$1", "int compare(java.lang.Object x, java.lang.Object y)", "Main..compare.(x: String, y: String): Int", generated = true ) - decoder.assertDecode( + decoder.assertDecodeClass( "example.Main$$anon$2", "Main." ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$$anon$2", "boolean isDefinedAt(java.lang.String x)", "Main..isDefinedAt(x: String): Boolean" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$$anon$2", "boolean isDefinedAt(java.lang.Object x)", "Main..isDefinedAt.(x: String): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$$anon$2", "java.lang.Object applyOrElse(java.lang.String x, scala.Function1 default)", "Main..applyOrElse[A1, B1](x: A1, default: A1 => B1): B1" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$$anon$2", "java.lang.Object applyOrElse(java.lang.Object x, scala.Function1 default)", "Main..applyOrElse.[A1, B1](x: A1, default: A1 => B1): B1", @@ -1446,14 +576,14 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.String m$default$1()", "A.m.: String") - decoder.assertDecode("example.A", "int m$default$2()", "A.m.: Int") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "java.lang.String m$default$1()", "A.m.: String") + decoder.assertDecodeMethod("example.A", "int m$default$2()", "A.m.: Int") + decoder.assertDecodeMethod( "example.A$", "java.lang.String $lessinit$greater$default$1()", "A..: String" ) - decoder.assertDecode("example.A$", "int $lessinit$greater$default$2()", "A..: Int") + decoder.assertDecodeMethod("example.A$", "int $lessinit$greater$default$2()", "A..: Int") } test("matches on return types") { @@ -1470,14 +600,14 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "int m(scala.collection.immutable.List xs)", "A.m(xs: List[Int]): Int") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "int m(scala.collection.immutable.List xs)", "A.m(xs: List[Int]): Int") + decoder.assertDecodeMethod( "example.B", "int m(scala.collection.immutable.List xs)", "B.m.(xs: List[Int]): Int", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String m(scala.collection.immutable.List xs)", "B.m(xs: List[String]): String" @@ -1524,7 +654,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco val decoder = TestingDecoder(source, scalaVersion) def assertDecode(javaSig: String, expected: String)(using munit.Location): Unit = - decoder.assertDecode("example.Main$", javaSig, expected) + decoder.assertDecodeMethod("example.Main$", javaSig, expected) assertDecode("example.A m(example.A a)", "Main.m(a: A): A") assertDecode("example.A$B mbis(example.A$B b)", "Main.mbis(b: A.B): A.B") @@ -1555,7 +685,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "int m1(java.lang.String x)", "A.m1(x: \"a\"): 1") + decoder.assertDecodeMethod("example.A", "int m1(java.lang.String x)", "A.m1(x: \"a\"): 1") } test("type aliases") { @@ -1571,7 +701,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "java.lang.String m(example.A x)", "Main.m(x: Foo): Bar") + decoder.assertDecodeMethod("example.Main$", "java.lang.String m(example.A x)", "Main.m(x: Foo): Bar") } test("refined types") { @@ -1592,8 +722,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "example.B m1()", "Main.m1(): A & B {...}") - decoder.assertDecode("example.Main$", "java.lang.Object m2()", "Main.m2(): Object {...}") + decoder.assertDecodeMethod("example.Main$", "example.B m1()", "Main.m1(): A & B {...}") + decoder.assertDecodeMethod("example.Main$", "java.lang.Object m2()", "Main.m2(): Object {...}") } test("type parameters") { @@ -1610,8 +740,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B", "example.A m1(example.A x)", "B.m1(x: X): X") - decoder.assertDecode("example.B", "example.A m2(example.A x)", "B.m2[T](x: T): T") + decoder.assertDecodeMethod("example.B", "example.A m1(example.A x)", "B.m1(x: X): X") + decoder.assertDecodeMethod("example.B", "example.A m2(example.A x)", "B.m2[T](x: T): T") } test("nested classes") { @@ -1628,7 +758,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "scala.Enumeration$Value today()", "Main.today(): Enumeration.Value") + decoder.assertDecodeMethod("example.Main$", "scala.Enumeration$Value today()", "Main.today(): Enumeration.Value") } test("matches Null and Nothing") { @@ -1641,8 +771,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "scala.runtime.Nothing$ m(int[] xs)", "Main.m(xs: Array[Int]): Nothing") - decoder.assertDecode( + decoder.assertDecodeMethod("example.Main$", "scala.runtime.Nothing$ m(int[] xs)", "Main.m(xs: Array[Int]): Nothing") + decoder.assertDecodeMethod( "example.Main$", "scala.runtime.Null$ m(java.lang.String[] xs)", "Main.m(xs: Array[String]): Null" @@ -1658,7 +788,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Main$", "java.lang.Object m(java.lang.Object xs)", "Main.m[T](xs: Array[T]): Array[T]" @@ -1674,7 +804,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.Object m(java.lang.Object x)", "A.m[T](x: B[T]): B[T]") + decoder.assertDecodeMethod("example.A", "java.lang.Object m(java.lang.Object x)", "A.m[T](x: B[T]): B[T]") } test("constructors and trait constructors") { @@ -1688,8 +818,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |class B extends A |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "void $init$(example.A $this)", "A.(): Unit") - decoder.assertDecode("example.B", "void ()", "B.(): Unit") + decoder.assertDecodeMethod("example.A", "void $init$(example.A $this)", "A.(): Unit") + decoder.assertDecodeMethod("example.B", "void ()", "B.(): Unit") } test("vararg type") { @@ -1701,7 +831,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String m(scala.collection.immutable.Seq as)", "A.m(as: String*): String" @@ -1722,8 +852,12 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "java.lang.String $amp(example.$less$greater x)", "Main.&(x: <>): String") - decoder.assertDecode("example.$less$greater", "example.$less$greater m()", "<>.m: <>") + decoder.assertDecodeMethod( + "example.Main$", + "java.lang.String $amp(example.$less$greater x)", + "Main.&(x: <>): String" + ) + decoder.assertDecodeMethod("example.$less$greater", "example.$less$greater m()", "<>.m: <>") } test("local recursive method") { @@ -1742,7 +876,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "int rec$1(int x, int acc)", "Main.fac.rec(x: Int, acc: Int): Int") + decoder.assertDecodeMethod("example.Main$", "int rec$1(int x, int acc)", "Main.fac.rec(x: Int, acc: Int): Int") } test("local lazy initializer") { @@ -1761,13 +895,13 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String y$1(scala.runtime.LazyRef y$lzy1$2, java.lang.String x$2)", "A.m.y: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String y$lzyINIT1$1(scala.runtime.LazyRef y$lzy1$1, java.lang.String x$1)", "A.m.y.: String" @@ -1796,8 +930,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Outer", "java.lang.String example$Outer$$foo()", "Outer.foo: String") - decoder.assertDecode("example.A$", "int example$A$$$m()", "A.m: Int") + decoder.assertDecodeMethod("example.Outer", "java.lang.String example$Outer$$foo()", "Outer.foo: String") + decoder.assertDecodeMethod("example.A$", "int example$A$$$m()", "A.m: Int") } test("type lambda") { @@ -1811,21 +945,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$", "example.Foo foo()", "Main.foo: Foo[[X] =>> Either[X, Int]]") - } - - test("local enum") { - val source = - """|package example - |object Main : - | def m = - | enum A: - | case B - | () - |""".stripMargin - - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Main$A$1", "Main.m.A") + decoder.assertDecodeMethod("example.Main$", "example.Foo foo()", "Main.foo: Foo[[X] =>> Either[X, Int]]") } test("package object") { @@ -1835,8 +955,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.package$", "java.lang.String foo()", "example.foo: String") - decoder.assertDecode( + decoder.assertDecodeMethod("example.package$", "java.lang.String foo()", "example.foo: String") + decoder.assertDecodeMethod( "example.package", "java.lang.String foo()", "example.foo.: String", @@ -1851,8 +971,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |def foo: String = ??? |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.Test$package$", "java.lang.String foo()", "example.foo: String") - decoder.assertDecode( + decoder.assertDecodeMethod("example.Test$package$", "java.lang.String foo()", "example.foo: String") + decoder.assertDecodeMethod( "example.Test$package", "java.lang.String foo()", "example.foo.: String", @@ -1870,8 +990,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.String m()", "A.m: String", generated = true) - decoder.assertDecode("example.A", "java.lang.String m(java.lang.String x)", "A.m(x: String): String") + decoder.assertDecodeMethod("example.A", "java.lang.String m()", "A.m: String", generated = true) + decoder.assertDecodeMethod("example.A", "java.lang.String m(java.lang.String x)", "A.m(x: String): String") } test("adapted anon fun") { @@ -1883,7 +1003,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "boolean m$$anonfun$adapted$1(java.lang.Object _$1)", "A.m..(Char): Boolean", @@ -1923,33 +1043,33 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B1$", "scala.Function0 B1$$superArg$1()", "B1..: () => \"\"") - decoder.assertDecode( + decoder.assertDecodeMethod("example.B1$", "scala.Function0 B1$$superArg$1()", "B1..: () => \"\"") + decoder.assertDecodeMethod( "example.B1$", "scala.Function1 example$B1$$$B2$$superArg$1()", "B1.B2..: String => String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C1", "scala.Function0 C1$superArg$1()", "C1..: () => \"\"" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C1", "scala.Function1 example$C1$$C2$$superArg$1()", "C1.C2..: String => String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C1", "scala.Function0 example$C1$$_$C3$superArg$1$1()", "C1.m.C3..: () => \"\"" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C1", "scala.Function0 example$C1$$_$$anon$superArg$1$1()", "C1.m...: () => \"\"" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C1$C3$1", "scala.Function1 example$C1$C3$1$$C4$superArg$1()", "C1.m.….C4..: String => String" @@ -1966,33 +1086,50 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | def m(x: String): Int ?=> String ?=> String = ??? |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( - "example.A", - if isScala33 then "java.lang.String m(int x, java.lang.String evidence$1)" - else "java.lang.String m(int x, java.lang.String contextual$1)", - "A.m(x: Int): String ?=> String" - ) - decoder.assertDecode( - "example.A", - if isScala33 then "int m(int evidence$2, java.lang.String evidence$3)" - else "int m(int contextual$2, java.lang.String contextual$3)", - "A.m(): (Int, String) ?=> Int" - ) - decoder.assertDecode( - "example.A", - if isScala33 then "java.lang.String m(java.lang.String x, int evidence$4, java.lang.String evidence$5)" - else "java.lang.String m(java.lang.String x, int contextual$4, java.lang.String contextual$5)", - "A.m(x: String): Int ?=> String ?=> String" - ) - if isScala34 then - val source = + if isScala33 then + decoder.assertDecodeMethod( + "example.A", + "java.lang.String m(int x, java.lang.String evidence$1)", + "A.m(x: Int): String ?=> String" + ) + decoder.assertDecodeMethod( + "example.A", + "int m(int evidence$2, java.lang.String evidence$3)", + "A.m(): (Int, String) ?=> Int" + ) + decoder.assertDecodeMethod( + "example.A", + "java.lang.String m(java.lang.String x, int evidence$4, java.lang.String evidence$5)", + "A.m(x: String): Int ?=> String ?=> String" + ) + else + decoder.assertDecodeMethod( + "example.A", + "java.lang.String m(int x, java.lang.String contextual$1)", + "A.m(x: Int): String ?=> String" + ) + decoder.assertDecodeMethod( + "example.A", + "int m(int contextual$2, java.lang.String contextual$3)", + "A.m(): (Int, String) ?=> Int" + ) + decoder.assertDecodeMethod( + "example.A", + "java.lang.String m(java.lang.String x, int contextual$4, java.lang.String contextual$5)", + "A.m(x: String): Int ?=> String ?=> String" + ) + val source2 = """|package example | |class A: | def mbis: ? ?=> String = ??? |""".stripMargin - val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.String mbis(java.lang.Object contextual$1)", "A.mbis: ? ?=> String") + val decoder2 = TestingDecoder(source, scalaVersion) + decoder2.assertDecodeMethod( + "example.A", + "java.lang.String mbis(java.lang.Object contextual$1)", + "A.mbis: ? ?=> String" + ) } test("trait param") { @@ -2005,11 +1142,11 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) // todo fix: should be a BinaryTraitParamGetter - decoder.assertDecode("example.B", "int x()", "B.x: Int", generated = true) - decoder.assertDecode("example.B", "int y()", "B.y: Int", generated = true) - decoder.assertDecode("example.B", "void y_$eq(int x$1)", "B.y_=(Int): Unit", generated = true) - decoder.assertDecode("example.B", "int example$A$$z()", "B.z: Int", generated = true) - decoder.assertDecode("example.B", "java.lang.String example$A$$x$4()", "B.x$4: String", generated = true) + decoder.assertDecodeMethod("example.B", "int x()", "B.x: Int", generated = true) + decoder.assertDecodeMethod("example.B", "int y()", "B.y: Int", generated = true) + decoder.assertDecodeMethod("example.B", "void y_$eq(int x$1)", "B.y_=(Int): Unit", generated = true) + decoder.assertDecodeMethod("example.B", "int example$A$$z()", "B.z: Int", generated = true) + decoder.assertDecodeMethod("example.B", "java.lang.String example$A$$x$4()", "B.x$4: String", generated = true) } test("lifted try") { @@ -2033,11 +1170,11 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | def m4 = "" + m3 |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.String liftedTree1$1()", "A.: \"\" | \"\"") - decoder.assertDecode("example.A", "java.lang.String liftedTree2$1()", "A.: \"\" | \"\"") - decoder.assertDecode("example.A", "java.lang.String liftedTree3$1()", "A.m1.: \"\" | \"\"") - decoder.assertDecode("example.A", "int liftedTree4$1()", "A.m1.m2.: 2 | 3") - decoder.assertDecode("example.A", "java.lang.String liftedTree5$1()", "A.m4.: \"\" | \"\"") + decoder.assertDecodeMethod("example.A", "java.lang.String liftedTree1$1()", "A.: \"\" | \"\"") + decoder.assertDecodeMethod("example.A", "java.lang.String liftedTree2$1()", "A.: \"\" | \"\"") + decoder.assertDecodeMethod("example.A", "java.lang.String liftedTree3$1()", "A.m1.: \"\" | \"\"") + decoder.assertDecodeMethod("example.A", "int liftedTree4$1()", "A.m1.m2.: 2 | 3") + decoder.assertDecodeMethod("example.A", "java.lang.String liftedTree5$1()", "A.m4.: \"\" | \"\"") } test("by-name args") { @@ -2054,9 +1191,9 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |} |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.Object foo(scala.Function0 x)", "A.foo[T](x: => T): T") - decoder.assertDecode("example.A", "java.lang.String $init$$$anonfun$1()", "A.: String") - decoder.assertDecode("example.A", "int m$$anonfun$1()", "A.m.: Int") + decoder.assertDecodeMethod("example.A", "java.lang.Object foo(scala.Function0 x)", "A.foo[T](x: => T): T") + decoder.assertDecodeMethod("example.A", "java.lang.String $init$$$anonfun$1()", "A.: String") + decoder.assertDecodeMethod("example.A", "int m$$anonfun$1()", "A.m.: Int") } test("inner object") { @@ -2074,14 +1211,19 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "example.A$B$ B()", "A.B: B") - decoder.assertDecode("example.A", "example.A$B$ B$(example.A $this)", "A.B.: B", generated = true) - decoder.assertDecode("example.C$", "example.A$B$ B()", "C.B: B", generated = true) - decoder.assertDecode("example.E", "example.E$F$ F()", "E.F: F", generated = true) - decoder.assertDecode("example.E", "example.A$B$ B()", "E.B: B", generated = true) - decoder.assertDecode("example.C$", "java.lang.Object B$lzyINIT1()", "C.B.: B", generated = true) - decoder.assertDecode("example.E", "java.lang.Object F$lzyINIT1()", "E.F.: F") - decoder.assertDecode("example.E", "java.lang.Object B$lzyINIT2()", "E.B.: B", generated = true) + decoder.assertDecodeMethod("example.A", "example.A$B$ B()", "A.B: B") + decoder.assertDecodeMethod( + "example.A", + "example.A$B$ B$(example.A $this)", + "A.B.: B", + generated = true + ) + decoder.assertDecodeMethod("example.C$", "example.A$B$ B()", "C.B: B", generated = true) + decoder.assertDecodeMethod("example.E", "example.E$F$ F()", "E.F: F", generated = true) + decoder.assertDecodeMethod("example.E", "example.A$B$ B()", "E.B: B", generated = true) + decoder.assertDecodeMethod("example.C$", "java.lang.Object B$lzyINIT1()", "C.B.: B", generated = true) + decoder.assertDecodeMethod("example.E", "java.lang.Object F$lzyINIT1()", "E.F.: F") + decoder.assertDecodeMethod("example.E", "java.lang.Object B$lzyINIT2()", "E.B.: B", generated = true) } test("static forwarder") { @@ -2095,7 +1237,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |object B extends A[String] |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String foo(java.lang.Object arg0)", "B.foo.(x: String): String", @@ -2112,7 +1254,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |class B(foo: String) extends A(foo) |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B", "java.lang.String foo$accessor()", "B.foo: String", generated = true) + decoder.assertDecodeMethod("example.B", "java.lang.String foo$accessor()", "B.foo: String", generated = true) } test("trait setters") { @@ -2127,19 +1269,19 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |object C extends A |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "void example$A$_setter_$example$A$$foo_$eq(java.lang.String x$0)", "B.foo.(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C$", "void example$A$_setter_$example$A$$foo_$eq(java.lang.String x$0)", "C.foo.(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C", "void example$A$_setter_$example$A$$foo_$eq(java.lang.String arg0)", "C.foo..(String): Unit", @@ -2161,19 +1303,19 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String example$B$$super$foo(java.lang.Object x)", "B.foo.(x: T): String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C", "java.lang.String example$B$$super$foo(java.lang.String x)", "C.foo.(x: String): String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C", "java.lang.String example$B$$super$foo(java.lang.Object x)", "C.foo..(x: String): String", @@ -2202,24 +1344,24 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val javaModule = Module.fromJavaSource(javaSource, scalaVersion) val decoder = TestingDecoder(source, scalaVersion, javaModule.classpath) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String m(java.lang.Object[] args)", "B.m.(args: Any*): String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String m(scala.collection.immutable.Seq args)", "B.m(args: Any*): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "int m(java.lang.String[] args)", "B.m.(args: String*): Int", generated = true ) - decoder.assertDecode("example.B", "int m(scala.collection.immutable.Seq args)", "B.m(args: String*): Int") + decoder.assertDecodeMethod("example.B", "int m(scala.collection.immutable.Seq args)", "B.m(args: String*): Int") } test("specialized methods") { @@ -2232,38 +1374,38 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |object B extends A |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "boolean apply(double x)", "A.apply(x: Double): Boolean") - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "boolean apply(double x)", "A.apply(x: Double): Boolean") + decoder.assertDecodeMethod( "example.A", "java.lang.Object apply(java.lang.Object v1)", "A.apply.(x: Double): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "boolean apply$mcZD$sp(double x)", "A.apply.(x: Double): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "int apply$mcII$sp(int x$0)", "A.apply.(x: Double): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "boolean apply(double arg0)", "B.apply.(x: Double): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "boolean apply$mcZD$sp(double arg0)", "B.apply..(x: Double): Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "int apply$mcII$sp(int arg0)", "B.apply..(x: Double): Boolean", @@ -2286,8 +1428,12 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | inline override def m[T](x: => T): T = x |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B", "java.lang.String x$proxy2$1(java.lang.String s$1)", "B.m.: String") - decoder.assertDecode("example.B$", "java.lang.Object x$proxy1$1(scala.Function0 x$1)", "B.m.: T") + decoder.assertDecodeMethod( + "example.B", + "java.lang.String x$proxy2$1(java.lang.String s$1)", + "B.m.: String" + ) + decoder.assertDecodeMethod("example.B$", "java.lang.Object x$proxy1$1(scala.Function0 x$1)", "B.m.: T") } test("inline accessor") { @@ -2309,61 +1455,61 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | inline def m: String = x + x |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String inline$x$i2(example.A$AA x$0)", "A.: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String inline$x$i2$(example.A $this, example.A$AA x$0)", "A..: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "void inline$x_$eq$i2(example.A$AA x$0, java.lang.String x$0)", "A.(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "void inline$x_$eq$i2$(example.A $this, example.A$AA x$0, java.lang.String x$0)", "A..(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String inline$x$i2(example.A$AA x$0)", "B..: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "void inline$x_$eq$i2(example.A$AA x$0, java.lang.String x$0)", "B..(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.String inline$y()", "B..: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "void inline$y_$eq(java.lang.String arg0)", "B..(String): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C$", "java.lang.String inline$x$extension(java.lang.String $this)", "C.: String", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.C", "java.lang.String inline$x$extension(java.lang.String arg0)", "C..: String", @@ -2379,7 +1525,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | val x: String => String = identity |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A$", "java.lang.Object $deserializeLambda$(java.lang.invoke.SerializedLambda arg0)", "A.$deserializeLambda$(arg0: SerializedLambda): Object" @@ -2395,7 +1541,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | case C extends A("c") |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "void (java.lang.String x, java.lang.String _$name, int _$ordinal)", "A.(x: String): Unit" @@ -2413,14 +1559,14 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | (x, y) |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "java.lang.Object $1$$lzyINIT1()", "A..: (String, String)") - decoder.assertDecode("example.A", "scala.Tuple2 $1$()", "A.: (String, String)", generated = true) - decoder.assertDecode( + decoder.assertDecodeMethod("example.A", "java.lang.Object $1$$lzyINIT1()", "A..: (String, String)") + decoder.assertDecodeMethod("example.A", "scala.Tuple2 $1$()", "A.: (String, String)", generated = true) + decoder.assertDecodeMethod( "example.A", "scala.Tuple2 $2$$lzyINIT1$1(scala.runtime.LazyRef $2$$lzy1$1)", "A.m..: (String, String)" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "scala.Tuple2 $2$$1(scala.runtime.LazyRef $2$$lzy1$2)", "A.m.: (String, String)", @@ -2441,7 +1587,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | () |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", "java.lang.String example$A$$_$m3$1$(example.A $this)", "A.m1.m3.: String", @@ -2463,13 +1609,13 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | } yield formatted |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", if isScala33 then "scala.Option m$$anonfun$2(scala.Tuple2 x$1)" else "scala.Option m$$anonfun$1(scala.Tuple2 x$1)", "A.m.((String, String)): Option[String]" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.A", if isScala33 then "scala.Option m$$anonfun$2$$anonfun$2(scala.Tuple2 x$1)" else "scala.Option m$$anonfun$1$$anonfun$2(scala.Tuple2 x$1)", @@ -2505,37 +1651,37 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String $anonfun$1(scala.collection.immutable.List xs$1, java.lang.String x)", "Test.test.(x: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String $anonfun$2(example.Logger Logger_this$1, scala.Function1 f$proxy1$1, java.lang.String x)", "Logger.m2.(x: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String test$$anonfun$1(java.lang.String name$1, java.lang.String y)", "Test.test..(y: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String test$$anonfun$2(java.lang.String a$proxy1$1, java.lang.String y)", "Test.test..(y: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String test$$anonfun$3(java.lang.String y)", "Test.test..(y: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String test$$anonfun$4(java.lang.String x$2, java.lang.String y)", "Test.test..(y: String): String" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test", "java.lang.String $anonfun$1$$anonfun$1(java.lang.String x$1, java.lang.String y)", "Test.test..(y: String): String" @@ -2545,7 +1691,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco test("tastyquery#395"): assume(!isJava8) val decoder = initDecoder("de.sciss", "desktop-core_3", "0.11.4") - decoder.assertDecode( + decoder.assertDecodeMethod( "de.sciss.desktop.impl.LogPaneImpl$textPane$", "boolean apply$mcZD$sp(double x$0)", "LogPaneImpl.textPane.apply.(str: String): Unit", @@ -2554,20 +1700,20 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco test("tasty-query#397 and tasty-query#413"): val decoder = initDecoder("com.github.xuwei-k", "httpz_3", "0.8.0") - decoder.assertDecode("httpz.package$$anon$1", "httpz.ActionZipAp.") - decoder.assertDecode("httpz.InterpretersTemplate$$anon$5", "InterpretersTemplate.times.….apply.") - decoder.assertDecode( + decoder.assertDecodeClass("httpz.package$$anon$1", "httpz.ActionZipAp.") + decoder.assertDecodeClass("httpz.InterpretersTemplate$$anon$5", "InterpretersTemplate.times.….apply.") + decoder.assertDecodeMethod( "httpz.package$", "java.lang.Object httpz$package$$anon$1$$_$ap$$anonfun$1(scala.Function1 _$2, java.lang.Object _$3)", "httpz.ActionZipAp.….ap.(A => B, A): B" ) - decoder.assertDecode( + decoder.assertDecodeMethod( "httpz.Response", "scalaz.Equal responseEqual(scalaz.Equal arg0)", "Response.responseEqual.[A](implicit Equal[A]): Equal[Response[A]]", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "httpz.Core", "scalaz.$bslash$div jsonResponse$$anonfun$3$$anonfun$1(argonaut.DecodeJson A$2, httpz.Request request$1, httpz.Response json)", "Core.jsonResponse..(json: Response[Json]): \\/[Error, Response[A]]" @@ -2575,7 +1721,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco test("tasty-query#398"): val decoder = initDecoder("io.github.ashwinbhaskar", "sight-client_3", "0.1.2") - decoder.assertDecode( + decoder.assertDecodeMethod( "sight.client.SightClientImpl", "java.lang.String b$1(scala.Tuple2 x$1$2)", "SightClientImpl.constructPayload.…..b: String" @@ -2605,7 +1751,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | rec(c) |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.Test$", "scala.Product foo$$anonfun$1(scala.Tuple2 x)", "Foo.drop1[A](a: A): DropUnits[A]" @@ -2614,7 +1760,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco test("tasty-query#402"): val decoder = initDecoder("com.softwaremill.sttp.client3", "opentelemetry_3", "3.6.1") - decoder.assertDecode( + decoder.assertDecodeMethod( "sttp.client3.opentelemetry.OpenTelemetryTracingBackend", "java.lang.Object send$$anonfun$2$$anonfun$1$$anonfun$1(sttp.client3.RequestT request$5, scala.collection.mutable.Map carrier$5)", "OpenTelemetryTracingBackend.send..: F[Response[T]]" @@ -2637,8 +1783,8 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.B", "void m(java.lang.String x)", "B.m(x: Type): Unit") - decoder.assertDecode("example.C", "void m(java.lang.Integer x)", "C.m(x: Value[T]): Unit") + decoder.assertDecodeMethod("example.B", "void m(java.lang.String x)", "B.m(x: Type): Unit") + decoder.assertDecodeMethod("example.C", "void m(java.lang.Integer x)", "C.m(x: Value[T]): Unit") } test("tasty-query#407") { @@ -2655,74 +1801,92 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | def test: Consumer[String] = m(println) |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.A", "void test$$anonfun$1(java.lang.String x)", "A.test.(x: String): Unit") + decoder.assertDecodeMethod( + "example.A", + "void test$$anonfun$1(java.lang.String x)", + "A.test.(x: String): Unit" + ) } test("scala3-compiler:3.3.1"): val decoder = initDecoder("org.scala-lang", "scala3-compiler_3", "3.3.1") - decoder.assertDecode( + decoder.assertDecodeMethod( "scala.quoted.runtime.impl.QuotesImpl", "boolean scala$quoted$runtime$impl$QuotesImpl$$inline$xCheckMacro()", "QuotesImpl.: Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "dotty.tools.dotc.printing.RefinedPrinter", "void dotty$tools$dotc$printing$RefinedPrinter$$inline$myCtx_$eq(dotty.tools.dotc.core.Contexts$Context x$0)", "RefinedPrinter.(Contexts.Context): Unit", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "dotty.tools.dotc.transform.sjs.PrepJSInterop$OwnerKind", "int inline$baseKinds$extension(int arg0)", "PrepJSInterop.OwnerKind..: Int", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "org.scalajs.ir.Trees$OptimizerHints", "boolean inline$extension(int arg0)", "Trees.OptimizerHints.inline.: Boolean", generated = true ) - decoder.assertDecode( + decoder.assertDecodeMethod( "dotty.tools.package", "java.lang.Object unreachable$default$1()", "tools.unreachable..: Any", generated = true ) - // decoder.assertDecode( + // decoder.assertDecodeMethod( // "dotty.tools.dotc.printing.Formatting$StringFormatter", // "java.lang.String assemble$$anonfun$1(java.lang.String str)", // "" // ) - decoder.assertDecode( + decoder.assertDecodeMethod( "dotty.tools.dotc.core.tasty.TreeUnpickler", "dotty.tools.dotc.ast.Trees$Tree dotty$tools$dotc$core$tasty$TreeUnpickler$TreeReader$$_$_$$anonfun$18(dotty.tools.dotc.core.Contexts$Context x$1$19, dotty.tools.dotc.core.tasty.TreeUnpickler$TreeReader $this$tailLocal1$1)", "TreeUnpickler.readTpt.()(using Contexts.Context): tpd.Tree", generated = true ) - decoder.assertNotFoundVariable( - "scala.quoted.runtime.impl.QuotesImpl$reflect$defn$", - "dotty.tools.dotc.core.Symbols$Symbol TupleClass(int arity)", - "dotty.tools.dotc.core.Types$TypeRef x$proxy1", - 2816 + decoder.assertDecodeMethod( + "dotty.tools.dotc.typer.Implicits$OfTypeImplicits", + "scala.collection.immutable.List refs()", + "Implicits.OfTypeImplicits.refs: List[Types.ImplicitRef]" ) - - test("tasty-query#412"): - val decoder = initDecoder("dev.zio", "zio-interop-cats_3", "23.1.0.0")(using ThrowOrWarn.ignore) - decoder.assertDecode("zio.BuildInfoInteropCats", "BuildInfoInteropCats") - decoder.assertDecode("zio.interop.ZioMonadErrorE$$anon$4", "ZioMonadErrorE.adaptError.") - - test("tasty-query#414"): - val decoder = initDecoder("io.github.dieproht", "matr-dflt-data_3", "0.0.3") - decoder.assertDecode( - "matr.dflt.DefaultMatrixFactory$$anon$1", - "DefaultMatrixFactory.defaultMatrixFactory.builder." + decoder.assertDecodeMethod( + "dotty.tools.dotc.typer.Typer", + "dotty.tools.dotc.core.Types$TermRef $anonfun$78(dotty.tools.dotc.core.Contexts$Context x$4$68, dotty.tools.dotc.core.Types$TermRef ref$4, dotty.tools.dotc.core.Denotations$SingleDenotation alt)", + "Typer.adapt1.adaptOverloaded.(alt: Denotations.SingleDenotation): TermRef" + ) + decoder.assertDecodeMethod( + "org.scalajs.ir.Trees$JSGlobalRef", + "java.lang.Object $init$$$anonfun$1(java.lang.String name$1)", + "Trees.JSGlobalRef.: Any" + ) + /* decoder.assertDecodeMethod( + "scala.collection.IterableOnceOps", + "java.lang.Object maxBy(scala.Function1 f, scala.math.Ordering ord)", + "" + ) */ + decoder.assertDecodeMethod( + "dotty.tools.dotc.typer.Applications$", + "dotty.tools.dotc.typer.Applications$tupleFold$2$ tupleFold$1(dotty.tools.dotc.core.Contexts$Context x$2$19, scala.runtime.LazyRef tupleFold$lzy1$2)", + "Applications.foldApplyTupleType.tupleFold: tupleFold", + generated = true + ) + decoder.assertDecodeMethod( + "dotty.tools.backend.sjs.JSCodeGen", + "org.scalajs.ir.Trees$Tree $anonfun$43(org.scalajs.ir.Trees$VarRef overloadVar$6, org.scalajs.ir.Position pos$67, dotty.tools.backend.sjs.JSCodeGen$ConstructorTree _$21)", + "JSCodeGen.genJSClassCtorBody.postStats.(ConstructorTree[SplitSecondaryJSCtor]): Trees.Tree" + ) + decoder.assertDecodeMethod( + "dotty.tools.dotc.typer.Synthesizer", + "scala.collection.immutable.List synthArgManifests$1(dotty.tools.dotc.typer.Synthesizer$ManifestKind kind$4, long span$22, dotty.tools.dotc.core.Contexts$Context evidence$17$3, java.lang.Object tp)", + "Synthesizer.manifestFactoryOf.…..synthArgManifests(tp: Manifestable): List[tpd.Tree]" ) - - test("tasty-query#415"): - val decoder = initDecoder("com.github.mkroli", "dns4s-fs2_3", "0.21.0") - decoder.assertDecode("com.github.mkroli.dns4s.fs2.DnsClientOps", "DnsClientOps") test("bug: Type.of creates capture".ignore): val source = @@ -2740,7 +1904,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco | |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode( + decoder.assertDecodeMethod( "example.B", "java.lang.Object m$$anonfun$1(scala.quoted.Type tpe$1, java.lang.String x)", "B.m.(x: String): q.reflect.TypeRepr" @@ -2748,7 +1912,7 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco test("bug: not an outer".ignore): val decoder = initDecoder("com.disneystreaming", "weaver-monix-core_3", "0.6.15")(using ThrowOrWarn.ignore) - decoder.assertDecode( + decoder.assertDecodeMethod( "weaver.monixcompat.PureTaskSuite", "weaver.SourceLocation$ weaver$SourceLocationMacro$Here$$$outer()", "" @@ -2764,34 +1928,6 @@ abstract class BinaryDecoderTests(scalaVersion: ScalaVersion) extends BinaryDeco |object $ |""".stripMargin val decoder = TestingDecoder(source, scalaVersion) - decoder.assertDecode("example.$", "java.lang.String $()", "$.$: String") - decoder.assertDecode("example.$$", "$") + decoder.assertDecodeMethod("example.$", "java.lang.String $()", "$.$: String") + decoder.assertDecodeClass("example.$$", "$") } - - test("tasty-query#423"): - val decoder = initDecoder("com.typesafe.akka", "akka-stream_3", "2.8.5") - decoder.assertDecode("akka.stream.scaladsl.FlowOps$passedEnd$2$", "FlowOps.zipAllFlow.passedEnd") - - test("tasty-query#424"): - val decoder = initDecoder("edu.gemini", "lucuma-itc-core_3", "0.10.0") - decoder.assertDecode("lucuma.itc.ItcImpl", "ItcImpl") - - test("specialized class"): - val decoder = initDecoder("org.scala-lang", "scala-library", "2.13.12") - decoder.assertDecode("scala.runtime.java8.JFunction1$mcII$sp", "JFunction1$mcII$sp") - - test("local class in value class"): - val source = - """|package example - | - |class A(self: String) extends AnyVal: - | def m(size: Int): String = - | class B: - | def m(): String = - | self.take(size) - | val b = new B - | b.m() - |""".stripMargin - // tasty-query#428 - val decoder = TestingDecoder(source, scalaVersion)(using ThrowOrWarn.ignore) - decoder.assertDecode("example.A$B$1", "A.m.B") diff --git a/src/test/scala/ch/epfl/scala/decoder/BinaryVariableDecoderTests.scala b/src/test/scala/ch/epfl/scala/decoder/BinaryVariableDecoderTests.scala new file mode 100644 index 0000000..c6a0652 --- /dev/null +++ b/src/test/scala/ch/epfl/scala/decoder/BinaryVariableDecoderTests.scala @@ -0,0 +1,710 @@ +package ch.epfl.scala.decoder + +import ch.epfl.scala.decoder.testutils.* +import tastyquery.Exceptions.* + +import scala.util.Properties + +class Scala3NextBinaryVariableDecoderTests extends Scala3LtsBinaryVariableDecoderTests: + override val scalaVersion = ScalaVersion.`3.next` + +class Scala3LtsBinaryVariableDecoderTests extends BinaryDecoderSuite: + val scalaVersion: ScalaVersion = ScalaVersion.`3.lts` + def isScala33 = scalaVersion.isScala33 + def isScala34 = scalaVersion.isScala34 + + test("local variable") { + val source = + """|package example + | + |class A: + | def foo = + | val x: Int = 1 + | x + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int foo()", "int x", 6, "x: Int") + } + + test("local array variable") { + val source = + """|package example + | + |class A { + | def foo() = + | val x = Array(1, 2, 3) + | x + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int[] foo()", "int[] x", 6, "x: Array[Int]") + } + + test("local lazy val".ignore) { + val source = + """|package example + | + |class A: + | def foo() = + | lazy val x: Int = 1 + | x + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int foo()", "int x$1$lzyVal", 6, "x: Int") + } + + test("local module val".ignore) { + val source = + """|package example + | + |class A { + | def foo() = + | object B + | B + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "java.lang.Object foo()", " example.A$B$2$ B$1", 6, "B: B.type") + } + + test("method parameter") { + val source = + """|package example + | + |class A: + | def foo(y: String) = + | println(y) + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "void foo(java.lang.String y)", "java.lang.String y", 5, "y: String") + } + + test("constructor parameters and variables") { + val source = + """|package example + | + |class A(x: Int): + | private val y = { + | val z = x * x + | z + | } + | + | def this(x: Int, y: Int) = + | this(x + y) + | val z = x + y + | println(z) + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "void (int x)", "int x", 3, "x: Int") + decoder.assertDecodeVariable("example.A", "void (int x)", "int y", 6, "y: Int") + decoder.assertDecodeVariable("example.A", "void (int x)", "int z", 7, "z: Int") + decoder.assertDecodeVariable("example.A", "void (int x, int y)", "int x", 10, "x: Int") + decoder.assertDecodeVariable("example.A", "void (int x, int y)", "int y", 10, "y: Int") + decoder.assertDecodeVariable("example.A", "void (int x, int y)", "int z", 12, "z: Int") + } + + test("contextual parameters") { + val source = + """|package example + | + |class A: + | def m(x: Int): String ?=> String = ??? + | def m(): (Int, String) ?=> Int = ??? + | def m(x: String): Int ?=> String ?=> String = ??? + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + if isScala33 then + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(int x, java.lang.String evidence$1)", + "java.lang.String evidence$1", + 4, + ": String" + ) + decoder.assertDecodeVariable( + "example.A", + "int m(int evidence$2, java.lang.String evidence$3)", + "int evidence$2", + 5, + ": Int" + ) + decoder.assertDecodeVariable( + "example.A", + "int m(int evidence$2, java.lang.String evidence$3)", + "java.lang.String evidence$3", + 5, + ": String" + ) + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(java.lang.String x, int evidence$4, java.lang.String evidence$5)", + "int evidence$4", + 6, + ": Int" + ) + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(java.lang.String x, int evidence$4, java.lang.String evidence$5)", + "java.lang.String evidence$5", + 6, + ": String" + ) + else + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(int x, java.lang.String contextual$1)", + "java.lang.String contextual$1", + 4, + ": String" + ) + decoder.assertDecodeVariable( + "example.A", + "int m(int contextual$2, java.lang.String contextual$3)", + "int contextual$2", + 5, + ": Int" + ) + decoder.assertDecodeVariable( + "example.A", + "int m(int contextual$2, java.lang.String contextual$3)", + "java.lang.String contextual$3", + 5, + ": String" + ) + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(java.lang.String x, int contextual$4, java.lang.String contextual$5)", + "int contextual$4", + 6, + ": Int" + ) + decoder.assertDecodeVariable( + "example.A", + "java.lang.String m(java.lang.String x, int contextual$4, java.lang.String contextual$5)", + "java.lang.String contextual$5", + 6, + ": String" + ) + } + + test("ambiguous local variables") { + val source = + """|package example + | + |class A: + | def foo() = + | var i = 0 + | while i < 10 do + | val x = i + | i += 1 + | val x = 17 + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + // TODO fix ambiguity somehow + decoder.assertAmbiguousVariable("example.A", "void foo()", "int x", 8) + decoder.assertDecodeVariable("example.A", "void foo()", "int x", 9, "x: Int") + } + + test("ambiguous variables and parameters") { + val source = + """|package example + | + |class A : + | def foo(a: Boolean) = + | if a then + | val x = 1 + | x + | else + | val x = "2" + | x + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "java.lang.Object foo(boolean a)", "int x", 7, "x: Int") + decoder.assertDecodeVariable("example.A", "java.lang.Object foo(boolean a)", "java.lang.String x", 9, "x: String") + } + + test("failing ambiguous local variables") { + val source = + """|package example + | + |class A: + | def foo(a: Boolean) = + | if (a) {val x = 1} else {val x = 2} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertAmbiguousVariable("example.A", "void foo(boolean a)", "int x", 5) + } + + test("binds for tuple") { + val source = + """|package example + | + |class A { + | def foo: Int = + | val x = (1, 2) + | val (c, d) = (3, 4) + | x match + | case (a, b) => a + b + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int foo()", "scala.Tuple2 x", 5, "x: (Int, Int)") + decoder.assertDecodeVariable("example.A", "int foo()", "int c", 6, "c: Int") + decoder.assertDecodeVariable("example.A", "int foo()", "int d", 6, "d: Int") + decoder.assertDecodeVariable("example.A", "int foo()", "int a", 8, "a: Int") + decoder.assertDecodeVariable("example.A", "int foo()", "int b", 8, "b: Int") + } + + test("binds for case classes") { + val source = + """|package example + | + |class B + |case class C(x: Int, y: String) extends B + |case class D(z: String) extends B + |case class E(v: Int) extends B + |case class F(w: Int) extends B + | + |class A: + | private def bar(a: B) = + | a match + | case F(w) => w + | case C(x, y) => + | x + | case D(z) => 0 + | case E(v) => 1 + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int w", 12, "w: Int") + decoder.assertDecodeVariable("example.A", "int bar(example.B a)", "int x", 14, "x: Int") + } + + test("this variable") { + val source = + """|package example + | + |class A: + | def foo: Int = + | 4 + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int foo()", "example.A this", 5, "this: A.this.type") + } + + test("value class this") { + val source = + """|package example + | + |class A(x: Int) extends AnyVal { + | def foo: Int = + | Seq(1, 2).map(_ + x).sum + | + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A$", "int foo$extension(int $this)", "int $this", 5, "x: Int") + decoder.assertDecodeVariable( + "example.A$", + "int foo$extension$$anonfun$1(int $this$1, int _$1)", + "int $this$1", + 5, + "x.: Int" + ) + decoder.assertDecodeVariable( + "example.A$", + "int foo$extension$$anonfun$1(int $this$1, int _$1)", + "int _$1", + 5, + ": Int" + ) + } + + test("parameters of mixin and trait static forwarders") { + val source = + """|package example + | + |trait A { + | def foo(x: Int): Int = x + |} + | + |class B extends A + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.B", "int foo(int x)", "int x", 7, "x: Int") + decoder.assertDecodeVariable( + "example.A", + "int foo$(example.A $this, int x)", + "example.A $this", + 4, + "this: A.this.type" + ) + } + + test("captured param in a local def") { + val source = + """|package example + | + |class A { + | def foo(x: Int) = { + | def bar() = x + | } + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int bar$1(int x$1)", "int x$1", 5, "x.: Int") + } + + test("by-name arg capture") { + val source = + """|package example + | + |class A { + | def foo(x: => Int) = ??? + | + | def bar(x: Int) = + | foo(x) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int bar$$anonfun$1(int x$1)", "int x$1", 7, "x.: Int") + } + + test("lazy val capture") { + val source = + """|package example + | + |class A { + | def foo = + | val y = 4 + | lazy val z = y + 1 + | def bar = z + | z + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable( + "example.A", + "int bar$1(scala.runtime.LazyInt z$lzy1$3, int y$3)", + "scala.runtime.LazyInt z$lzy1$3", + 7, + "z.: Int" + ) + decoder.assertDecodeVariable( + "example.A", + "int z$lzyINIT1$1(scala.runtime.LazyInt z$lzy1$1, int y$1)", + "scala.runtime.LazyInt z$lzy1$1", + 7, + "z.: Int" + ) + decoder.assertDecodeVariable( + "example.A", + "int z$1(scala.runtime.LazyInt z$lzy1$2, int y$2)", + "scala.runtime.LazyInt z$lzy1$2", + 7, + "z.: Int" + ) + } + + test("bridge parameter") { + val source = + """|package example + |class A + |class B extends A + | + |class C: + | def foo(x: Int): A = new A + | + |class D extends C: + | override def foo(y: Int): B = new B + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertIgnoredVariable("example.D", "example.A foo(int x)", "int x", "Bridge") + } + + test("inlined param") { + val source = + """|package example + | + |class A { + | inline def foo(x: Int): Int = x + x + | + | def bar(y: Int) = foo(y + 2) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A", "int bar(int y)", "int x$proxy1", 4, "x: Int") + } + + test("inlined this") { + val source = + """|package example + | + |class A(x: Int): + | inline def foo: Int = x + x + | + |class B: + | def bar(a: A) = a.foo + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.B", "int bar(example.A a)", "example.A A_this", 7, "this: A.this.type") + } + + test("partial function variables") { + val source = + """|package example + | + |class A: + | def foo(x: Int) = + | val xs = List(x, x + 1, x + 2) + | xs.collect { case z if z % 2 == 0 => z } + | + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable("example.A$$anon$1", "boolean isDefinedAt(int x)", "int x", 6, "x: A") + decoder.assertDecodeVariable("example.A$$anon$1", "boolean isDefinedAt(int x)", "int z", 6, "z: Int") + decoder.assertDecodeVariable( + "example.A$$anon$1", + "java.lang.Object applyOrElse(int x, scala.Function1 default)", + "scala.Function1 default", + 6, + "default: A1 => B1" + ) + } + + test("tail-local variables".ignore): + val source = + """|package example + | + |class A { + | @annotation.tailrec + | private def factAcc(x: Int, acc: Int): Int = + | if x <= 1 then List(1, 2).map(_ * acc).sum + | else factAcc(x - 1, x * acc) + |} + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.showVariables("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)") + // decoder.assertDecodeVariable("example.A", "int factAcc$$anonfun$1(int acc$tailLocal1$1, int _$1)", "int acc$tailLocal1$1", 6, "acc.: Int") + + test("inlined lambda"): + val source1 = + """|package example + | + |class Context + | + |object Context: + | inline def inContext[T](c: Context)(inline op: Context ?=> T): T = + | op(using c) + |""".stripMargin + val source2 = + """|package example + | + |import Context.* + | + |class A: + | def m: String = + | val c = Context() + | inContext(c) { + | val x = "foo" + | x + summon[Context].toString + | } + |""".stripMargin + val decoder = TestingDecoder(Seq(source1, source2), scalaVersion) + decoder.assertDecodeVariable("example.A", "java.lang.String m()", "java.lang.String x", 10, "x: String") + + test("captured anon val"): + val source = + """|package example + | + |class A: + | def foo(x: String): String = "a" + x + | + |class B: + | var a: A = A() + | def m: List[String] = + | List("bar").map(a.foo) + |""".stripMargin + val decoder = TestingDecoder(source, scalaVersion) + decoder.assertDecodeVariable( + "example.B", + "java.lang.String m$$anonfun$1(example.A $1$$1, java.lang.String x)", + "example.A $1$$1", + 9, + ".: A" + ) + + test("scala3-compiler:3.3.1"): + val decoder = initDecoder("org.scala-lang", "scala3-compiler_3", "3.3.1") + /* decoder.assertNotFoundVariable( + "scala.quoted.runtime.impl.QuotesImpl$reflect$defn$", + "dotty.tools.dotc.core.Symbols$Symbol TupleClass(int arity)", + "dotty.tools.dotc.core.Types$TypeRef x$proxy1", + 2816 + ) */ + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.QuoteMatcher$", + "scala.Option treeMatch(dotty.tools.dotc.ast.Trees$Tree scrutineeTree, dotty.tools.dotc.ast.Trees$Tree patternTree, dotty.tools.dotc.core.Contexts$Context x$3)", + "scala.util.boundary$Break ex", + 128, + "ex: Break[T]" + ) + decoder.assertDecodeVariable( + "org.scalajs.ir.VersionChecks", + "void (java.lang.String current, java.lang.String binaryEmitted)", + "java.lang.String current", + 26, + "current: String" + ) + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.QuoteMatcher$", + "scala.collection.immutable.Seq $eq$qmark$eq(dotty.tools.dotc.ast.Trees$Tree scrutinee0, dotty.tools.dotc.ast.Trees$Tree pattern0, scala.collection.immutable.Map x$3, dotty.tools.dotc.core.Contexts$Context x$4, scala.util.boundary$Label evidence$4)", + "dotty.tools.dotc.ast.Trees$Tree s", + 192, + "s: Tree[Types.Type]" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.typer.Namer$Completer", + "scala.collection.immutable.List completerTypeParams(dotty.tools.dotc.core.Symbols$Symbol sym, dotty.tools.dotc.core.Contexts$Context x$2)", + "dotty.tools.dotc.core.Symbols$Symbol sym", + 829, + "sym: Symbols.Symbol" + ) + decoder.assertDecodeVariable( + "dotty.tools.scripting.Util$", + "scala.collection.IterableOnce collectMainMethods$1$$anonfun$1(java.net.URLClassLoader cl$1, java.lang.String targetPath$1, java.io.File packageMember)", + "java.lang.String targetPath$1", + 35, + "targetPath.: String" + ) + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.printers.SourceCode$SourceCodePrinter", + "java.lang.Object printTree$$anonfun$adapted$5(scala.Option elideThis$48, scala.runtime.ObjectRef argsPrefix$2, scala.collection.immutable.List args1$2)", + "scala.runtime.ObjectRef argsPrefix$2", + 408, + "argsPrefix.: String" + ) + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.printers.SourceCode$SourceCodePrinter", + "scala.quoted.runtime.impl.printers.SourceCode$SourceCodePrinter printTree(java.lang.Object tree, scala.Option elideThis)", + "scala.runtime.ObjectRef argsPrefix", + 390, + "argsPrefix: String" + ) + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.printers.SourceCode$SourceCodePrinter", + "java.lang.String escapedString$$anonfun$adapted$1(java.lang.Object ch)", + "java.lang.Object ch", + 1435, + "ch: Char" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.typer.Implicits$OfTypeImplicits", + "scala.collection.immutable.List refs()", + "scala.collection.mutable.ListBuffer buf", + 276, + "buf: ListBuffer[Types.TermRef]" + ) + // decode from Java class + decoder.assertDecodeVariable( + "dotty.tools.io.JDK9Reflectors", + "java.util.jar.JarFile newJarFile(java.io.File arg0, boolean arg1, int arg2, java.lang.Object arg3)", + "java.lang.Object arg3", + 63, + "x$3: Object" + ) + decoder.assertDecodeVariable( + "dotty.tools.repl.package$$anon$1", + "void dotty$tools$dotc$reporting$UniqueMessagePositions$_setter_$dotty$tools$dotc$reporting$UniqueMessagePositions$$positions_$eq(scala.collection.mutable.HashMap x$0)", + "scala.collection.mutable.HashMap x$0", + 8, + "x$0: HashMap[(SourceFile, Integer), Diagnostic]" + ) + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.QuoteMatcher$", + "scala.collection.immutable.Seq $eq$qmark$eq(dotty.tools.dotc.ast.Trees$Tree scrutinee0, dotty.tools.dotc.ast.Trees$Tree pattern0, scala.collection.immutable.Map x$3, dotty.tools.dotc.core.Contexts$Context x$4, scala.util.boundary$Label evidence$4)", + "dotty.tools.dotc.ast.Trees$Tree s", + 491, + "s: Tree[Types.Type]" + ) + // transparent inline + decoder.assertDecodeVariable( + "dotty.tools.dotc.typer.Namer$TypeDefCompleter", + "dotty.tools.dotc.core.Contexts$Context given_Context$lzyINIT1$1(scala.runtime.LazyRef given_Context$lzy1$1)", + "dotty.tools.dotc.core.Contexts$Context x$proxy4", + 965, + "x$proxy4: Contexts.Context | Null" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.util.HashSet", + "void copyFrom(java.lang.Object[] oldTable)", + "java.lang.Object x$proxy7", + 158, + "x$proxy7: e.type & T" + ) + // inline def in stdLibPatches + decoder.assertDecodeVariable( + "scala.quoted.runtime.impl.QuotesImpl$reflect$defn$", + "dotty.tools.dotc.core.Symbols$Symbol TupleClass(int arity)", + "dotty.tools.dotc.core.Types$TypeRef x$proxy1", + 2816, + "x: T" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.util.SimpleIdentityMap", + "boolean apply$mcZD$sp(double x$0)", + "double x$0", + 8, + "k.: K" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.util.StackTraceOps$", + "java.lang.Object unseen$1$$anonfun$adapted$1(scala.collection.mutable.Set seen$5, java.lang.Throwable e$5, java.lang.Object v1)", + "java.lang.Object v1", + 48, + ": Boolean" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.Driver", + "scala.Tuple2 setup$$anonfun$1(dotty.tools.dotc.core.Contexts$FreshContext ictx$1, scala.collection.immutable.List fileNames)", + "dotty.tools.dotc.core.Contexts$FreshContext ictx$1", + 88, + "ictx.: FreshContext" + ) + decoder.assertDecodeVariable( + "dotty.tools.repl.ReplDriver", + "scala.Tuple2 renderDefinitions$$anonfun$3(dotty.tools.repl.State state$13)", + "dotty.tools.repl.State state$13", + 436, + "state.: State" + ) + decoder.assertDecodeVariable( + "dotty.tools.repl.ReplDriver", + "scala.math.Ordering given_Ordering_Diagnostic$lzyINIT1$1(scala.runtime.LazyRef given_Ordering_Diagnostic$lzy1$1)", + "scala.runtime.LazyRef given_Ordering_Diagnostic$lzy1$1", + 336, + "given_Ordering_Diagnostic.: Ordering[Diagnostic]" + ) + decoder.assertDecodeVariable( + "dotty.tools.repl.ReplCompiler", + "scala.util.Either typeCheck$$anonfun$1(boolean errorsAllowed$2, dotty.tools.dotc.util.SourceFile src$1, dotty.tools.dotc.core.Contexts$FreshContext c$proxy1$1, dotty.tools.dotc.ast.Trees$PackageDef pkg)", + "dotty.tools.dotc.core.Contexts$FreshContext c$proxy1$1", + 211, + "c.: Context" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.typer.VarianceChecker$Validator$", + "scala.Option apply(scala.Option status, dotty.tools.dotc.core.Types$Type tp)", + "dotty.tools.dotc.reporting.trace$ TraceSyntax_this", + 131, + "this: trace.this.type" + ) + decoder.assertDecodeVariable( + "dotty.tools.dotc.typer.Typer", + "java.lang.String prefix$1(dotty.tools.dotc.ast.Trees$Tree res$3, scala.runtime.LazyRef $29$$lzy1$3)", + "scala.runtime.LazyRef $29$$lzy1$3", + 2890, + ".: (String, String)" + ) diff --git a/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala b/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala index 4639075..94a8d88 100644 --- a/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala +++ b/src/test/scala/ch/epfl/scala/decoder/testutils/BinaryDecoderSuite.scala @@ -34,48 +34,53 @@ trait BinaryDecoderSuite extends CommonFunSuite: def showVariables(className: String, method: String): Unit = val variables = loadBinaryMethod(className, method).variables println( - s"Available binary variables in $method are:\n" + variables.map(f => s" " + formatVariable(f)).mkString("\n") + s"Available binary variables in $method are:\n" + variables + .map(v => s" " + formatVariable(v) + " " + v.showSpan) + .mkString("\n") ) def showFields(className: String): Unit = val fields = decoder.classLoader.loadClass(className).declaredFields println(s"Available binary fields in $className are:\n" + fields.map(f => s" " + formatField(f)).mkString("\n")) - def assertDecode(className: String, expected: String)(using munit.Location): Unit = + def assertDecodeClass(className: String, expected: String)(using munit.Location): Unit = val cls = decoder.classLoader.loadClass(className) val decodedClass = decoder.decode(cls) assertEquals(formatter.format(decodedClass), expected) - def assertDecode(className: String, method: String, expected: String, generated: Boolean = false)(using + def assertDecodeMethod(className: String, method: String, expected: String, generated: Boolean = false)(using munit.Location ): Unit = val binaryMethod = loadBinaryMethod(className, method) val decodedMethod = decoder.decode(binaryMethod) assertEquals(formatter.format(decodedMethod), expected) - assertEquals(decodedMethod.isGenerated, generated) + assertEquals(decodedMethod.isGenerated(using decoder.context), generated) - def assertDecodeField(className: String, field: String, expected: String, generated: Boolean = false)(using + def assertDecodeField(className: String, field: String, expected: String)(using munit.Location ): Unit = val binaryField: binary.Field = loadBinaryField(className, field) val decodedField = decoder.decode(binaryField) assertEquals(formatter.format(decodedField), expected) - assertEquals(decodedField.isGenerated, generated) - - def assertDecodeVariable( - className: String, - method: String, - variable: String, - expected: String, - line: Int, - generated: Boolean = false - )(using + + def assertDecodeVariable(className: String, method: String, variable: String, line: Int, expected: String)(using munit.Location ): Unit = val binaryVariable = loadBinaryVariable(className, method, variable) val decodedVariable = decoder.decode(binaryVariable, line) assertEquals(formatter.format(decodedVariable), expected) - assertEquals(decodedVariable.isGenerated, generated) + + def assertNotFoundMethod(declaringType: String, javaSig: String)(using munit.Location): Unit = + val method = loadBinaryMethod(declaringType, javaSig) + intercept[NotFoundException](decoder.decode(method)) + + def assertAmbiguousField(className: String, field: String)(using munit.Location): Unit = + val binaryField: binary.Field = loadBinaryField(className, field) + intercept[AmbiguousException](decoder.decode(binaryField)) + + def assertNotFoundField(className: String, field: String)(using munit.Location): Unit = + val binaryField = loadBinaryField(className, field) + intercept[NotFoundException](decoder.decode(binaryField)) def assertAmbiguousVariable(className: String, method: String, variable: String, line: Int)(using munit.Location @@ -83,33 +88,22 @@ trait BinaryDecoderSuite extends CommonFunSuite: val binaryVariable = loadBinaryVariable(className, method, variable) intercept[AmbiguousException](decoder.decode(binaryVariable, line)) + def decodeVariable(className: String, method: String, variable: String, line: Int): Unit = + val binaryVariable = loadBinaryVariable(className, method, variable) + decoder.decode(binaryVariable, line) + def assertNotFoundVariable(className: String, method: String, variable: String, line: Int)(using munit.Location ): Unit = val binaryVariable = loadBinaryVariable(className, method, variable) intercept[NotFoundException](decoder.decode(binaryVariable, line)) - def assertNoSuchElementVariable(className: String, method: String, variable: String)(using munit.Location): Unit = - intercept[NoSuchElementException](loadBinaryVariable(className, method, variable)) - def assertIgnoredVariable(className: String, method: String, variable: String, reason: String)(using munit.Location ): Unit = val binaryVariable = loadBinaryVariable(className, method, variable) intercept[IgnoredException](decoder.decode(binaryVariable, 0)) - def assertAmbiguousField(className: String, field: String)(using munit.Location): Unit = - val binaryField: binary.Field = loadBinaryField(className, field) - intercept[AmbiguousException](decoder.decode(binaryField)) - - def assertNotFoundField(className: String, field: String)(using munit.Location): Unit = - val binaryField = loadBinaryField(className, field) - intercept[NotFoundException](decoder.decode(binaryField)) - - def assertNotFound(declaringType: String, javaSig: String)(using munit.Location): Unit = - val method = loadBinaryMethod(declaringType, javaSig) - intercept[NotFoundException](decoder.decode(method)) - def assertDecodeAllInClass( className: String )(expectedMethods: ExpectedCount = ExpectedCount(0), printProgress: Boolean = false)(using munit.Location): Unit = @@ -131,36 +125,39 @@ trait BinaryDecoderSuite extends CommonFunSuite: expectedMethods: ExpectedCount = ExpectedCount(0), expectedFields: ExpectedCount = ExpectedCount(0), expectedVariables: ExpectedCount = ExpectedCount(0), - printProgress: Boolean = false + printProgress: Boolean = false, + classFilter: Set[String] = Set.empty )(using munit.Location): Unit = - val (classCounter, methodCounter, fieldCounter, variableCounter) = decodeAll(printProgress) - if classCounter.throwables.nonEmpty then - classCounter.printThrowables() - classCounter.printThrowable(0) - else if methodCounter.throwables.nonEmpty then - methodCounter.printThrowables() - methodCounter.printThrowable(0) + val (classCounter, methodCounter, fieldCounter, variableCounter) = decodeAll(printProgress, classFilter) + if classCounter.throwables.nonEmpty then classCounter.printThrowable(0) + else if methodCounter.throwables.nonEmpty then methodCounter.printThrowable(0) else if variableCounter.throwables.nonEmpty then variableCounter.printThrowable(0) - // variableCounter.printNotFound(40) + // methodCounter.printNotFound(40) + variableCounter.printNotFound(40) classCounter.check(expectedClasses) methodCounter.check(expectedMethods) fieldCounter.check(expectedFields) variableCounter.check(expectedVariables) - def decodeAll(printProgress: Boolean = false): (Counter, Counter, Counter, Counter) = + def decodeAll( + printProgress: Boolean = false, + classFilter: Set[String] = Set.empty + ): (Counter, Counter, Counter, Counter) = val classCounter = Counter(decoder.name + " classes") val methodCounter = Counter(decoder.name + " methods") val fieldCounter = Counter(decoder.name + " fields") val variableCounter = Counter(decoder.name + " variables") for binaryClass <- decoder.allClasses + if classFilter.isEmpty || classFilter.contains(binaryClass.name) _ = if printProgress then println(s"\"${binaryClass.name}\"") decodedClass <- decoder.tryDecode(binaryClass, classCounter) _ = binaryClass.declaredFields.foreach(f => decoder.tryDecode(decodedClass, f, fieldCounter)) binaryMethod <- binaryClass.declaredMethods decodedMethod <- decoder.tryDecode(decodedClass, binaryMethod, methodCounter) binaryVariable <- binaryMethod.variables - do decoder.tryDecode(decodedMethod, binaryVariable, variableCounter) + line <- debugLine(binaryVariable) + do decoder.tryDecode(decodedMethod, binaryVariable, line, variableCounter) classCounter.printReport() methodCounter.printReport() fieldCounter.printReport() @@ -200,7 +197,7 @@ trait BinaryDecoderSuite extends CommonFunSuite: .find(v => formatVariable(v) == variableName) .getOrElse(throw new NoSuchElementException(notFoundMessage)) - private def tryDecode(cls: binary.ClassType, counter: Counter): Option[DecodedClass] = + private def tryDecode(cls: binary.BinaryClass, counter: Counter): Option[DecodedClass] = try val sym = decoder.decode(cls) counter.success += (cls -> sym) @@ -245,9 +242,9 @@ trait BinaryDecoderSuite extends CommonFunSuite: case ignored: IgnoredException => counter.ignored += ignored case e => counter.throwables += (field -> e) - private def tryDecode(mtd: DecodedMethod, variable: binary.Variable, counter: Counter): Unit = + private def tryDecode(mtd: DecodedMethod, variable: binary.Variable, line: Int, counter: Counter): Unit = try - val decoded = decoder.decode(mtd, variable, variable.sourceLines.get.lines.head) + val decoded = decoder.decode(mtd, variable, line) counter.success += (variable -> decoded) catch case notFound: NotFoundException => counter.notFound += (variable -> notFound) @@ -256,20 +253,23 @@ trait BinaryDecoderSuite extends CommonFunSuite: case e => counter.throwables += (variable -> e) end extension + private def debugLine(variable: binary.Variable): Option[Int] = + // It's tricky to guess a valid debug line. But in practice, it seems the middle one is a good guess. + variable.sourceLines.map(_.lines).filter(_.nonEmpty).map(lines => lines(lines.size / 2)) + private def formatDebug(m: binary.Symbol): String = m match - case f: binary.Field => s"\"${f.declaringClass}\", \"${formatField(f)}\"" - case m: binary.Method => s"\"${m.declaringClass.name}\", \"${formatMethod(m)}\"" + case f: binary.Field => s"(\"${f.declaringClass}\", \"${formatField(f)}\")" + case m: binary.Method => s"(\"${m.declaringClass.name}\", \"${formatMethod(m)}\")" case v: binary.Variable => - s"\"${v.declaringMethod.declaringClass.name}\",\n \"${formatMethod( - v.declaringMethod - )}\",\n \"${formatVariable(v)}\",\n \"\",\n ${v.sourceLines.get.lines.head}, " - // s"\"${v.showSpan}" + s"""("${v.declaringMethod.declaringClass.name}", "${formatMethod(v.declaringMethod)}", "${formatVariable( + v + )}", ${debugLine(v).get})""".stripMargin case cls => s"\"${cls.name}\"" private def formatMethod(m: binary.Method): String = val returnType = m.returnType.map(_.name).get - val parameters = m.allParameters.map(p => p.`type`.name + " " + p.name).mkString(", ") + val parameters = m.parameters.map(p => p.`type`.name + " " + p.name).mkString(", ") s"$returnType ${m.name}($parameters)" private def formatField(f: binary.Field): String = @@ -326,7 +326,7 @@ trait BinaryDecoderSuite extends CommonFunSuite: def printComparisionWithJavaFormatting(): Unit = def formatJavaStyle(m: binary.Method): String = - s"${m.declaringClass.name}.${m.name}(${m.allParameters.map(_.`type`.name).mkString(",")})" + s"${m.declaringClass.name}.${m.name}(${m.parameters.map(_.`type`.name).mkString(",")})" val formatted = success .collect { case (m: binary.Method, d: DecodedMethod) => (formatJavaStyle(m), formatter.format(d)) } @@ -337,10 +337,7 @@ trait BinaryDecoderSuite extends CommonFunSuite: println(s"mean: ${formatted.map((j, s) => s.size - j.size).sum / formatted.size}") end printComparisionWithJavaFormatting - def printSuccess() = - success.foreach { (s, d) => - println(s"${formatDebug(s)}: $d") - } + def printSuccess() = success.foreach((s, _) => println(formatDebug(s))) def printNotFound() = notFound.foreach { case (s1, NotFoundException(s2, _)) => @@ -362,7 +359,8 @@ trait BinaryDecoderSuite extends CommonFunSuite: def printNotFound(n: Int) = notFound.take(n).foreach { case (s1, NotFoundException(s2, owner)) => if s1 != s2 then println(s"${formatDebug(s1)} not found because of ${formatDebug(s2)}") - else println(s"- ${formatDebug(s1)} not found " + owner.map(o => s"in ${o.getClass.getSimpleName()}")) + else + println(s"- ${formatDebug(s1)} not found " + owner.map(o => s"in ${o.getClass.getSimpleName}").getOrElse("")) println("") } diff --git a/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala b/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala index 98f6b4a..a727863 100644 --- a/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala +++ b/src/test/scala/ch/epfl/scala/decoder/testutils/TestingDecoder.scala @@ -14,6 +14,15 @@ import java.nio.file.FileSystem object TestingDecoder: def javaRuntime = JavaRuntime(Properties.jdkHome).get + def apply(sources: Seq[String], scalaVersion: ScalaVersion)(using ThrowOrWarn): TestingDecoder = + val module = Module.fromSources( + sources.zipWithIndex.map((s, i) => s"Test$i.scala" -> s), + scalaVersion, + Seq.empty, + Seq.empty + ) + TestingDecoder(module.mainEntry, module.classpath) + def apply(source: String, scalaVersion: ScalaVersion)(using ThrowOrWarn): TestingDecoder = val module = Module.fromSource(source, scalaVersion) TestingDecoder(module.mainEntry, module.classpath) @@ -43,7 +52,7 @@ class TestingDecoder(mainEntry: ClasspathEntry, val classLoader: BinaryClassLoad val binaryClass = classLoader.loadClass(cls) decode(binaryClass) def name: String = mainEntry.name - def allClasses: Seq[binary.ClassType] = + def allClasses: Seq[binary.BinaryClass] = def listClassNames(root: Path): Seq[String] = val classMatcher = root.getFileSystem.getPathMatcher("glob:**.class") Files