Skip to content

Decode fields #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 86 additions & 18 deletions src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import tastyquery.jdk.ClasspathLoaders

import java.nio.file.Path
import scala.util.matching.Regex
import tastyquery.Exceptions.NonMethodReferenceException

object BinaryDecoder:
def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder =
Expand Down Expand Up @@ -107,6 +108,75 @@ class BinaryDecoder(using Context, ThrowOrWarn):
candidates.singleOrThrow(method)
end decode

def decode(field: binary.Field): DecodedField =
val decodedClass = decode(field.declaringClass)
decode(decodedClass, field)

def decode(decodedClass: DecodedClass, field: binary.Field): DecodedField =
def tryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
f.applyOrElse(field, _ => Seq.empty[DecodedField])

extension (xs: Seq[DecodedField])
def orTryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
if xs.nonEmpty then xs else f.applyOrElse(field, _ => Seq.empty[DecodedField])
val decodedFields =
tryDecode {
case Patterns.LazyVal(name) =>
for
owner <- decodedClass.classSymbol.toSeq ++ decodedClass.linearization.filter(_.isTrait)
sym <- owner.declarations.collect {
case sym: TermSymbol if sym.nameStr == name && sym.isModuleOrLazyVal => sym
}
yield DecodedField.ValDef(decodedClass, sym)
case Patterns.Module() =>
decodedClass.classSymbol.flatMap(_.moduleValue).map(DecodedField.ModuleVal(decodedClass, _)).toSeq
case Patterns.Offset(nbr) =>
Seq(DecodedField.LazyValOffset(decodedClass, nbr, defn.LongType))
case Patterns.OuterField() =>
decodedClass.symbolOpt
.flatMap(_.outerClass)
.map(outerClass => DecodedField.Outer(decodedClass, outerClass.selfType))
.toSeq
case Patterns.SerialVersionUID() =>
Seq(DecodedField.SerialVersionUID(decodedClass, defn.LongType))
case Patterns.LazyValBitmap(name) =>
Seq(DecodedField.LazyValBitmap(decodedClass, defn.BooleanType, name))
case Patterns.AnyValCapture() =>
for
classSym <- decodedClass.symbolOpt.toSeq
outerClass <- classSym.outerClass.toSeq
if outerClass.isSubClass(defn.AnyValClass)
sym <- outerClass.declarations.collect {
case sym: TermSymbol if sym.isVal && !sym.isMethod => sym
}
yield DecodedField.Capture(decodedClass, sym)
case Patterns.Capture(names) =>
decodedClass.symbolOpt.toSeq
.flatMap(CaptureCollector.collectCaptures)
.filter { captureSym =>
names.exists {
case Patterns.LazyVal(name) => name == captureSym.nameStr
case name => name == captureSym.nameStr
}
}
.map(DecodedField.Capture(decodedClass, _))

case _ if field.isStatic && decodedClass.isJava =>
for
owner <- decodedClass.companionClassSymbol.toSeq
sym <- owner.declarations.collect { case sym: TermSymbol if sym.nameStr == field.name => sym }
yield DecodedField.ValDef(decodedClass, sym)
}.orTryDecode { case _ =>
for
owner <- withCompanionIfExtendsJavaLangEnum(decodedClass) ++ decodedClass.linearization.filter(_.isTrait)
sym <- owner.declarations.collect {
case sym: TermSymbol if matchTargetName(field, sym) && !sym.isMethod => sym
}
yield DecodedField.ValDef(decodedClass, sym)
}
decodedFields.singleOrThrow(field)
end decode

private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] =
if syms.size > 1 then
val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _)))
Expand Down Expand Up @@ -476,13 +546,8 @@ class BinaryDecoder(using Context, ThrowOrWarn):
.map(target => DecodedMethod.TraitStaticForwarder(decode(decodedClass, target)))

private def decodeOuter(decodedClass: DecodedClass): Option[DecodedMethod.OuterAccessor] =
def outerClass(sym: Symbol): Option[ClassSymbol] =
sym.owner match
case null => None
case owner if owner.isClass => Some(owner.asClass)
case owner => outerClass(owner)
decodedClass.symbolOpt
.flatMap(outerClass)
.flatMap(_.outerClass)
.map(outerClass => DecodedMethod.OuterAccessor(decodedClass, outerClass.thisType))

private def decodeTraitInitializer(
Expand Down Expand Up @@ -616,11 +681,18 @@ class BinaryDecoder(using Context, ThrowOrWarn):
DecodedMethod.MixinForwarder(decodedClass, staticForwarder.target)
}

private def withCompanionIfExtendsAnyVal(cls: ClassSymbol): Seq[ClassSymbol] =
cls.companionClass match
case Some(companionClass) if companionClass.isSubClass(defn.AnyValClass) =>
Seq(cls, companionClass)
case _ => Seq(cls)
private def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
case classDef: DecodedClass.ClassDef =>
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)

private def withCompanionIfExtendsJavaLangEnum(decodedClass: DecodedClass): Seq[ClassSymbol] =
decodedClass.classSymbol.toSeq.flatMap { cls =>
if cls.isSubClass(defn.javaLangEnumClass) then Seq(cls) ++ cls.companionClass
else Seq(cls)
}

private def decodeAdaptedAnonFun(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] =
if method.instructions.nonEmpty then
Expand Down Expand Up @@ -786,13 +858,6 @@ class BinaryDecoder(using Context, ThrowOrWarn):
private def collectLiftedTrees[S](decodedClass: DecodedClass, method: binary.Method)(
matcher: PartialFunction[LiftedTree[?], LiftedTree[S]]
): Seq[LiftedTree[S]] =
def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
case classDef: DecodedClass.ClassDef =>
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)

val owners = withCompanionIfExtendsAnyVal(decodedClass)
val sourceLines =
if owners.size == 2 && method.allParameters.exists(p => p.name.matches("\\$this\\$\\d+")) then
Expand Down Expand Up @@ -823,6 +888,9 @@ class BinaryDecoder(using Context, ThrowOrWarn):
private def matchTargetName(method: binary.Method, symbol: TermSymbol): Boolean =
method.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)

private def matchTargetName(field: binary.Field, symbol: TermSymbol): Boolean =
field.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)

private case class SourceParams(
declaredParamNames: Seq[UnsignedTermName],
declaredParamTypes: Seq[Type],
Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,35 @@ object DecodedMethod:
override def toString: String =
if underlying.isInstanceOf[InlinedMethod] then underlying.toString
else s"$underlying (inlined)"

sealed trait DecodedField extends DecodedSymbol:
def owner: DecodedClass
override def symbolOpt: Option[TermSymbol] = None
def declaredType: TypeOrMethodic

object DecodedField:
final class ValDef(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"ValDef($owner, ${symbol.showBasic})"

final class ModuleVal(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def symbolOpt: Option[TermSymbol] = Some(symbol)
override def toString: String = s"ModuleVal($owner, ${symbol.showBasic})"

final class LazyValOffset(val owner: DecodedClass, val index: Int, val declaredType: Type) extends DecodedField:
override def toString: String = s"LazyValOffset($owner, $index)"

final class Outer(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
override def toString: String = s"Outer($owner, ${declaredType.showBasic})"

final class SerialVersionUID(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
override def toString: String = s"SerialVersionUID($owner)"

final class Capture(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
def declaredType: TypeOrMethodic = symbol.declaredType
override def toString: String = s"Capture($owner, ${symbol.showBasic})"

final class LazyValBitmap(val owner: DecodedClass, val declaredType: Type, val name: String) extends DecodedField:
override def toString: String = s"LazyValBitmap($owner, , ${declaredType.showBasic})"
19 changes: 19 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import tastyquery.Types.*
import scala.annotation.tailrec

class StackTraceFormatter(using ThrowOrWarn):
def format(field: DecodedField): String =
val typeAscription = field.declaredType match
case tpe: Type => ": " + format(tpe)
case tpe => format(tpe)
formatOwner(field).dot(formatName(field)) + typeAscription

def format(cls: DecodedClass): String =
cls match
case cls: DecodedClass.ClassDef => formatQualifiedName(cls.symbol)
Expand Down Expand Up @@ -60,6 +66,19 @@ class StackTraceFormatter(using ThrowOrWarn):
case method: DecodedMethod.SAMOrPartialFunctionConstructor => format(method.owner)
case method: DecodedMethod.InlinedMethod => formatOwner(method.underlying)

private def formatOwner(field: DecodedField): String =
format(field.owner)

private def formatName(field: DecodedField): String =
field match
case field: DecodedField.ValDef => formatName(field.symbol)
case field: DecodedField.ModuleVal => ""
case field: DecodedField.LazyValOffset => "<offset " + field.index + ">"
case field: DecodedField.Outer => "<outer>"
case field: DecodedField.SerialVersionUID => "<serialVersionUID>"
case field: DecodedField.Capture => formatName(field.symbol).dot("<capture>")
case field: DecodedField.LazyValBitmap => field.name.dot("<lazy val bitmap>")

private def formatName(method: DecodedMethod): String =
method match
case method: DecodedMethod.ValOrDefDef => formatName(method.symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ trait ClassType extends Type:
def declaredField(name: String): Option[Field]
def declaredMethod(name: String, descriptor: String): Option[Method]
def declaredMethods: Seq[Method]
def declaredFields: Seq[Field]
def classLoader: BinaryClassLoader

def isObject = name.endsWith("$")
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/ch/epfl/scala/decoder/binary/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ package ch.epfl.scala.decoder.binary
trait Field extends Symbol:
def declaringClass: ClassType
def `type`: Type
def isStatic: Boolean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ch.epfl.scala.decoder.internal

import tastyquery.Trees.*
import scala.collection.mutable
import tastyquery.Symbols.*
import tastyquery.Traversers.*
import tastyquery.Contexts.*
import tastyquery.SourcePosition
import tastyquery.Types.*
import tastyquery.Traversers
import ch.epfl.scala.decoder.ThrowOrWarn
import scala.languageFeature.postfixOps

object CaptureCollector:
def collectCaptures(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn): Set[TermSymbol] =
val collector = CaptureCollector(cls)
collector.traverse(cls.tree)
collector.capture.toSet

class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn) extends TreeTraverser:
val capture: mutable.Set[TermSymbol] = mutable.Set.empty
val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty

def loopCollect(symbol: Symbol)(collect: => Unit): Unit =
if !alreadySeen.contains(symbol) then
alreadySeen += symbol
collect
override def traverse(tree: Tree): Unit =
tree match
case _: TypeTree => ()
case ident: Ident =>
for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do
// check that sym is local
// and check that no owners of sym is cls
if !alreadySeen.contains(sym) then
if sym.isLocal then
if !ownersIsCls(sym) then capture += sym
if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse))
else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse))
case _ => super.traverse(tree)

def ownersIsCls(sym: Symbol): Boolean =
sym.owner match
case owner: Symbol =>
if owner == cls then true
else ownersIsCls(owner)
case null => false
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Definitions(using ctx: Context):
val PartialFunctionClass = scalaPackage.getDecl(typeName("PartialFunction")).get.asClass
val AbstractPartialFunctionClass = scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass
val SerializableClass = javaIoPackage.getDecl(typeName("Serializable")).get.asClass
val javaLangEnumClass = javaLangPackage.getDecl(typeName("Enum")).get.asClass

val SerializedLambdaType: Type = TypeRef(javaLangInvokePackage.packageRef, typeName("SerializedLambda"))
val DeserializeLambdaType = MethodType(List(SimpleName("arg0")), List(SerializedLambdaType), ObjectType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import tastyquery.Contexts.*
import tastyquery.SourcePosition
import tastyquery.Types.*
import tastyquery.Traversers
import tastyquery.Exceptions.NonMethodReferenceException
import ch.epfl.scala.decoder.ThrowOrWarn

/**
Expand Down
39 changes: 39 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,45 @@ object Patterns:
"(.+)\\$i\\d+".r.unapplySeq(xs(0)).map(_(0)).getOrElse(xs(0))
}

object LazyVal:
def unapply(field: binary.Field): Option[String] = unapply(field.decodedName)

def unapply(name: String): Option[String] =
"""(.*)\$lzy\d+""".r.unapplySeq(name).map(xs => xs(0).stripSuffix("$"))

object Module:
def unapply(field: binary.Field): Boolean = field.name == "MODULE$"

object Offset:
def unapply(field: binary.Field): Option[Int] =
"""OFFSET\$(?:_m_)?(\d+)""".r.unapplySeq(field.name).map(xs => xs(0).toInt)

object OuterField:
def unapply(field: binary.Field): Boolean = field.name == "$outer"

object SerialVersionUID:
def unapply(field: binary.Field): Boolean = field.name == "serialVersionUID"

object AnyValCapture:
def unapply(field: binary.Field): Boolean =
field.name.matches("\\$this\\$\\d+")

object Capture:
def unapply(field: binary.Field): Option[Seq[String]] =
field.extractFromDecodedNames("(.+)\\$\\d+".r)(xs => xs(0))

object LazyValBitmap:
def unapply(field: binary.Field): Option[String] =
"(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0))

extension (field: binary.Field)
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
val extracted = field.unexpandedDecodedNames
.flatMap(regex.unapplySeq)
.map(extract)
.distinct
if extracted.nonEmpty then Some(extracted) else None

extension (method: binary.Method)
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
val extracted = method.unexpandedDecodedNames
Expand Down
17 changes: 17 additions & 0 deletions src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ extension (symbol: Symbol)
def isInline = symbol.isTerm && symbol.asTerm.isInline
def nameStr: String = symbol.name.toString

def outerClass: Option[ClassSymbol] =
symbol.owner match
case null => None
case owner: ClassSymbol => Some(owner)
case owner => owner.outerClass

def showBasic =
val span = symbol.tree.map(_.pos) match
case Some(pos) if pos.isFullyDefined =>
Expand Down Expand Up @@ -307,3 +313,14 @@ extension (method: DecodedMethod)
case _: DecodedMethod.SAMOrPartialFunctionConstructor => true
case method: DecodedMethod.InlinedMethod => method.underlying.isGenerated
case _ => false

extension (field: DecodedField)
def isGenerated: Boolean =
field match
case field: DecodedField.ValDef => false
case field: DecodedField.ModuleVal => true
case field: DecodedField.LazyValOffset => true
case field: DecodedField.Outer => true
case field: DecodedField.SerialVersionUID => true
case field: DecodedField.Capture => true
case field: DecodedField.LazyValBitmap => true
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ class JavaReflectClass(cls: Class[?], extraInfo: ExtraClassInfo, override val cl
val methodInfo = extraInfo.getMethodInfo(sig)
JavaReflectConstructor(c, sig, methodInfo, classLoader)
}

override def declaredFields: Seq[binary.Field] =
cls.getDeclaredFields().map(f => JavaReflectField(f, classLoader))
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ch.epfl.scala.decoder.javareflect
import ch.epfl.scala.decoder.binary

import java.lang.reflect.Field
import java.lang.reflect.Modifier

class JavaReflectField(field: Field, loader: JavaReflectLoader) extends binary.Field:
override def name: String = field.getName
Expand All @@ -11,5 +12,9 @@ class JavaReflectField(field: Field, loader: JavaReflectLoader) extends binary.F

override def declaringClass: binary.ClassType = loader.loadClass(field.getDeclaringClass)

override def isStatic: Boolean = Modifier.isStatic(field.getModifiers)

override def `type`: binary.Type =
loader.loadClass(field.getType)

override def toString: String = field.toString
Loading