Skip to content

Commit df373c1

Browse files
authored
Merge pull request #2 from SaadAissa/decodeFields
Decode fields
2 parents 01d744e + ec47589 commit df373c1

17 files changed

+736
-43
lines changed

src/main/scala/ch/epfl/scala/decoder/BinaryDecoder.scala

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import tastyquery.jdk.ClasspathLoaders
1414

1515
import java.nio.file.Path
1616
import scala.util.matching.Regex
17+
import tastyquery.Exceptions.NonMethodReferenceException
1718

1819
object BinaryDecoder:
1920
def apply(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder =
@@ -107,6 +108,75 @@ class BinaryDecoder(using Context, ThrowOrWarn):
107108
candidates.singleOrThrow(method)
108109
end decode
109110

111+
def decode(field: binary.Field): DecodedField =
112+
val decodedClass = decode(field.declaringClass)
113+
decode(decodedClass, field)
114+
115+
def decode(decodedClass: DecodedClass, field: binary.Field): DecodedField =
116+
def tryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
117+
f.applyOrElse(field, _ => Seq.empty[DecodedField])
118+
119+
extension (xs: Seq[DecodedField])
120+
def orTryDecode(f: PartialFunction[binary.Field, Seq[DecodedField]]): Seq[DecodedField] =
121+
if xs.nonEmpty then xs else f.applyOrElse(field, _ => Seq.empty[DecodedField])
122+
val decodedFields =
123+
tryDecode {
124+
case Patterns.LazyVal(name) =>
125+
for
126+
owner <- decodedClass.classSymbol.toSeq ++ decodedClass.linearization.filter(_.isTrait)
127+
sym <- owner.declarations.collect {
128+
case sym: TermSymbol if sym.nameStr == name && sym.isModuleOrLazyVal => sym
129+
}
130+
yield DecodedField.ValDef(decodedClass, sym)
131+
case Patterns.Module() =>
132+
decodedClass.classSymbol.flatMap(_.moduleValue).map(DecodedField.ModuleVal(decodedClass, _)).toSeq
133+
case Patterns.Offset(nbr) =>
134+
Seq(DecodedField.LazyValOffset(decodedClass, nbr, defn.LongType))
135+
case Patterns.OuterField() =>
136+
decodedClass.symbolOpt
137+
.flatMap(_.outerClass)
138+
.map(outerClass => DecodedField.Outer(decodedClass, outerClass.selfType))
139+
.toSeq
140+
case Patterns.SerialVersionUID() =>
141+
Seq(DecodedField.SerialVersionUID(decodedClass, defn.LongType))
142+
case Patterns.LazyValBitmap(name) =>
143+
Seq(DecodedField.LazyValBitmap(decodedClass, defn.BooleanType, name))
144+
case Patterns.AnyValCapture() =>
145+
for
146+
classSym <- decodedClass.symbolOpt.toSeq
147+
outerClass <- classSym.outerClass.toSeq
148+
if outerClass.isSubClass(defn.AnyValClass)
149+
sym <- outerClass.declarations.collect {
150+
case sym: TermSymbol if sym.isVal && !sym.isMethod => sym
151+
}
152+
yield DecodedField.Capture(decodedClass, sym)
153+
case Patterns.Capture(names) =>
154+
decodedClass.symbolOpt.toSeq
155+
.flatMap(CaptureCollector.collectCaptures)
156+
.filter { captureSym =>
157+
names.exists {
158+
case Patterns.LazyVal(name) => name == captureSym.nameStr
159+
case name => name == captureSym.nameStr
160+
}
161+
}
162+
.map(DecodedField.Capture(decodedClass, _))
163+
164+
case _ if field.isStatic && decodedClass.isJava =>
165+
for
166+
owner <- decodedClass.companionClassSymbol.toSeq
167+
sym <- owner.declarations.collect { case sym: TermSymbol if sym.nameStr == field.name => sym }
168+
yield DecodedField.ValDef(decodedClass, sym)
169+
}.orTryDecode { case _ =>
170+
for
171+
owner <- withCompanionIfExtendsJavaLangEnum(decodedClass) ++ decodedClass.linearization.filter(_.isTrait)
172+
sym <- owner.declarations.collect {
173+
case sym: TermSymbol if matchTargetName(field, sym) && !sym.isMethod => sym
174+
}
175+
yield DecodedField.ValDef(decodedClass, sym)
176+
}
177+
decodedFields.singleOrThrow(field)
178+
end decode
179+
110180
private def reduceAmbiguityOnClasses(syms: Seq[DecodedClass]): Seq[DecodedClass] =
111181
if syms.size > 1 then
112182
val reduced = syms.filterNot(sym => syms.exists(enclose(sym, _)))
@@ -476,13 +546,8 @@ class BinaryDecoder(using Context, ThrowOrWarn):
476546
.map(target => DecodedMethod.TraitStaticForwarder(decode(decodedClass, target)))
477547

478548
private def decodeOuter(decodedClass: DecodedClass): Option[DecodedMethod.OuterAccessor] =
479-
def outerClass(sym: Symbol): Option[ClassSymbol] =
480-
sym.owner match
481-
case null => None
482-
case owner if owner.isClass => Some(owner.asClass)
483-
case owner => outerClass(owner)
484549
decodedClass.symbolOpt
485-
.flatMap(outerClass)
550+
.flatMap(_.outerClass)
486551
.map(outerClass => DecodedMethod.OuterAccessor(decodedClass, outerClass.thisType))
487552

488553
private def decodeTraitInitializer(
@@ -616,11 +681,18 @@ class BinaryDecoder(using Context, ThrowOrWarn):
616681
DecodedMethod.MixinForwarder(decodedClass, staticForwarder.target)
617682
}
618683

619-
private def withCompanionIfExtendsAnyVal(cls: ClassSymbol): Seq[ClassSymbol] =
620-
cls.companionClass match
621-
case Some(companionClass) if companionClass.isSubClass(defn.AnyValClass) =>
622-
Seq(cls, companionClass)
623-
case _ => Seq(cls)
684+
private def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
685+
case classDef: DecodedClass.ClassDef =>
686+
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
687+
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
688+
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
689+
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)
690+
691+
private def withCompanionIfExtendsJavaLangEnum(decodedClass: DecodedClass): Seq[ClassSymbol] =
692+
decodedClass.classSymbol.toSeq.flatMap { cls =>
693+
if cls.isSubClass(defn.javaLangEnumClass) then Seq(cls) ++ cls.companionClass
694+
else Seq(cls)
695+
}
624696

625697
private def decodeAdaptedAnonFun(decodedClass: DecodedClass, method: binary.Method): Seq[DecodedMethod] =
626698
if method.instructions.nonEmpty then
@@ -786,13 +858,6 @@ class BinaryDecoder(using Context, ThrowOrWarn):
786858
private def collectLiftedTrees[S](decodedClass: DecodedClass, method: binary.Method)(
787859
matcher: PartialFunction[LiftedTree[?], LiftedTree[S]]
788860
): Seq[LiftedTree[S]] =
789-
def withCompanionIfExtendsAnyVal(decodedClass: DecodedClass): Seq[Symbol] = decodedClass match
790-
case classDef: DecodedClass.ClassDef =>
791-
Seq(classDef.symbol) ++ classDef.symbol.companionClass.filter(_.isSubClass(defn.AnyValClass))
792-
case _: DecodedClass.SyntheticCompanionClass => Seq.empty
793-
case anonFun: DecodedClass.SAMOrPartialFunction => Seq(anonFun.symbol)
794-
case inlined: DecodedClass.InlinedClass => withCompanionIfExtendsAnyVal(inlined.underlying)
795-
796861
val owners = withCompanionIfExtendsAnyVal(decodedClass)
797862
val sourceLines =
798863
if owners.size == 2 && method.allParameters.exists(p => p.name.matches("\\$this\\$\\d+")) then
@@ -823,6 +888,9 @@ class BinaryDecoder(using Context, ThrowOrWarn):
823888
private def matchTargetName(method: binary.Method, symbol: TermSymbol): Boolean =
824889
method.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)
825890

891+
private def matchTargetName(field: binary.Field, symbol: TermSymbol): Boolean =
892+
field.unexpandedDecodedNames.map(_.stripSuffix("$")).contains(symbol.targetNameStr)
893+
826894
private case class SourceParams(
827895
declaredParamNames: Seq[UnsignedTermName],
828896
declaredParamTypes: Seq[Type],

src/main/scala/ch/epfl/scala/decoder/DecodedSymbol.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,35 @@ object DecodedMethod:
160160
override def toString: String =
161161
if underlying.isInstanceOf[InlinedMethod] then underlying.toString
162162
else s"$underlying (inlined)"
163+
164+
sealed trait DecodedField extends DecodedSymbol:
165+
def owner: DecodedClass
166+
override def symbolOpt: Option[TermSymbol] = None
167+
def declaredType: TypeOrMethodic
168+
169+
object DecodedField:
170+
final class ValDef(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
171+
def declaredType: TypeOrMethodic = symbol.declaredType
172+
override def symbolOpt: Option[TermSymbol] = Some(symbol)
173+
override def toString: String = s"ValDef($owner, ${symbol.showBasic})"
174+
175+
final class ModuleVal(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
176+
def declaredType: TypeOrMethodic = symbol.declaredType
177+
override def symbolOpt: Option[TermSymbol] = Some(symbol)
178+
override def toString: String = s"ModuleVal($owner, ${symbol.showBasic})"
179+
180+
final class LazyValOffset(val owner: DecodedClass, val index: Int, val declaredType: Type) extends DecodedField:
181+
override def toString: String = s"LazyValOffset($owner, $index)"
182+
183+
final class Outer(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
184+
override def toString: String = s"Outer($owner, ${declaredType.showBasic})"
185+
186+
final class SerialVersionUID(val owner: DecodedClass, val declaredType: Type) extends DecodedField:
187+
override def toString: String = s"SerialVersionUID($owner)"
188+
189+
final class Capture(val owner: DecodedClass, val symbol: TermSymbol) extends DecodedField:
190+
def declaredType: TypeOrMethodic = symbol.declaredType
191+
override def toString: String = s"Capture($owner, ${symbol.showBasic})"
192+
193+
final class LazyValBitmap(val owner: DecodedClass, val declaredType: Type, val name: String) extends DecodedField:
194+
override def toString: String = s"LazyValBitmap($owner, , ${declaredType.showBasic})"

src/main/scala/ch/epfl/scala/decoder/StackTraceFormatter.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ import tastyquery.Types.*
1010
import scala.annotation.tailrec
1111

1212
class StackTraceFormatter(using ThrowOrWarn):
13+
def format(field: DecodedField): String =
14+
val typeAscription = field.declaredType match
15+
case tpe: Type => ": " + format(tpe)
16+
case tpe => format(tpe)
17+
formatOwner(field).dot(formatName(field)) + typeAscription
18+
1319
def format(cls: DecodedClass): String =
1420
cls match
1521
case cls: DecodedClass.ClassDef => formatQualifiedName(cls.symbol)
@@ -60,6 +66,19 @@ class StackTraceFormatter(using ThrowOrWarn):
6066
case method: DecodedMethod.SAMOrPartialFunctionConstructor => format(method.owner)
6167
case method: DecodedMethod.InlinedMethod => formatOwner(method.underlying)
6268

69+
private def formatOwner(field: DecodedField): String =
70+
format(field.owner)
71+
72+
private def formatName(field: DecodedField): String =
73+
field match
74+
case field: DecodedField.ValDef => formatName(field.symbol)
75+
case field: DecodedField.ModuleVal => ""
76+
case field: DecodedField.LazyValOffset => "<offset " + field.index + ">"
77+
case field: DecodedField.Outer => "<outer>"
78+
case field: DecodedField.SerialVersionUID => "<serialVersionUID>"
79+
case field: DecodedField.Capture => formatName(field.symbol).dot("<capture>")
80+
case field: DecodedField.LazyValBitmap => field.name.dot("<lazy val bitmap>")
81+
6382
private def formatName(method: DecodedMethod): String =
6483
method match
6584
case method: DecodedMethod.ValOrDefDef => formatName(method.symbol)

src/main/scala/ch/epfl/scala/decoder/binary/ClassType.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ trait ClassType extends Type:
99
def declaredField(name: String): Option[Field]
1010
def declaredMethod(name: String, descriptor: String): Option[Method]
1111
def declaredMethods: Seq[Method]
12+
def declaredFields: Seq[Field]
1213
def classLoader: BinaryClassLoader
1314

1415
def isObject = name.endsWith("$")

src/main/scala/ch/epfl/scala/decoder/binary/Field.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ package ch.epfl.scala.decoder.binary
33
trait Field extends Symbol:
44
def declaringClass: ClassType
55
def `type`: Type
6+
def isStatic: Boolean
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package ch.epfl.scala.decoder.internal
2+
3+
import tastyquery.Trees.*
4+
import scala.collection.mutable
5+
import tastyquery.Symbols.*
6+
import tastyquery.Traversers.*
7+
import tastyquery.Contexts.*
8+
import tastyquery.SourcePosition
9+
import tastyquery.Types.*
10+
import tastyquery.Traversers
11+
import ch.epfl.scala.decoder.ThrowOrWarn
12+
import scala.languageFeature.postfixOps
13+
14+
object CaptureCollector:
15+
def collectCaptures(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn): Set[TermSymbol] =
16+
val collector = CaptureCollector(cls)
17+
collector.traverse(cls.tree)
18+
collector.capture.toSet
19+
20+
class CaptureCollector(cls: ClassSymbol | TermSymbol)(using Context, ThrowOrWarn) extends TreeTraverser:
21+
val capture: mutable.Set[TermSymbol] = mutable.Set.empty
22+
val alreadySeen: mutable.Set[Symbol] = mutable.Set.empty
23+
24+
def loopCollect(symbol: Symbol)(collect: => Unit): Unit =
25+
if !alreadySeen.contains(symbol) then
26+
alreadySeen += symbol
27+
collect
28+
override def traverse(tree: Tree): Unit =
29+
tree match
30+
case _: TypeTree => ()
31+
case ident: Ident =>
32+
for sym <- ident.safeSymbol.collect { case sym: TermSymbol => sym } do
33+
// check that sym is local
34+
// and check that no owners of sym is cls
35+
if !alreadySeen.contains(sym) then
36+
if sym.isLocal then
37+
if !ownersIsCls(sym) then capture += sym
38+
if sym.isMethod || sym.isLazyVal then loopCollect(sym)(sym.tree.foreach(traverse))
39+
else if sym.isModuleVal then loopCollect(sym)(sym.moduleClass.flatMap(_.tree).foreach(traverse))
40+
case _ => super.traverse(tree)
41+
42+
def ownersIsCls(sym: Symbol): Boolean =
43+
sym.owner match
44+
case owner: Symbol =>
45+
if owner == cls then true
46+
else ownersIsCls(owner)
47+
case null => false

src/main/scala/ch/epfl/scala/decoder/internal/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Definitions(using ctx: Context):
1515
val PartialFunctionClass = scalaPackage.getDecl(typeName("PartialFunction")).get.asClass
1616
val AbstractPartialFunctionClass = scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass
1717
val SerializableClass = javaIoPackage.getDecl(typeName("Serializable")).get.asClass
18+
val javaLangEnumClass = javaLangPackage.getDecl(typeName("Enum")).get.asClass
1819

1920
val SerializedLambdaType: Type = TypeRef(javaLangInvokePackage.packageRef, typeName("SerializedLambda"))
2021
val DeserializeLambdaType = MethodType(List(SimpleName("arg0")), List(SerializedLambdaType), ObjectType)

src/main/scala/ch/epfl/scala/decoder/internal/LiftedTreeCollector.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import tastyquery.Contexts.*
88
import tastyquery.SourcePosition
99
import tastyquery.Types.*
1010
import tastyquery.Traversers
11-
import tastyquery.Exceptions.NonMethodReferenceException
1211
import ch.epfl.scala.decoder.ThrowOrWarn
1312

1413
/**

src/main/scala/ch/epfl/scala/decoder/internal/Patterns.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,45 @@ object Patterns:
135135
"(.+)\\$i\\d+".r.unapplySeq(xs(0)).map(_(0)).getOrElse(xs(0))
136136
}
137137

138+
object LazyVal:
139+
def unapply(field: binary.Field): Option[String] = unapply(field.decodedName)
140+
141+
def unapply(name: String): Option[String] =
142+
"""(.*)\$lzy\d+""".r.unapplySeq(name).map(xs => xs(0).stripSuffix("$"))
143+
144+
object Module:
145+
def unapply(field: binary.Field): Boolean = field.name == "MODULE$"
146+
147+
object Offset:
148+
def unapply(field: binary.Field): Option[Int] =
149+
"""OFFSET\$(?:_m_)?(\d+)""".r.unapplySeq(field.name).map(xs => xs(0).toInt)
150+
151+
object OuterField:
152+
def unapply(field: binary.Field): Boolean = field.name == "$outer"
153+
154+
object SerialVersionUID:
155+
def unapply(field: binary.Field): Boolean = field.name == "serialVersionUID"
156+
157+
object AnyValCapture:
158+
def unapply(field: binary.Field): Boolean =
159+
field.name.matches("\\$this\\$\\d+")
160+
161+
object Capture:
162+
def unapply(field: binary.Field): Option[Seq[String]] =
163+
field.extractFromDecodedNames("(.+)\\$\\d+".r)(xs => xs(0))
164+
165+
object LazyValBitmap:
166+
def unapply(field: binary.Field): Option[String] =
167+
"(.+)bitmap\\$\\d+".r.unapplySeq(field.decodedName).map(xs => xs(0))
168+
169+
extension (field: binary.Field)
170+
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
171+
val extracted = field.unexpandedDecodedNames
172+
.flatMap(regex.unapplySeq)
173+
.map(extract)
174+
.distinct
175+
if extracted.nonEmpty then Some(extracted) else None
176+
138177
extension (method: binary.Method)
139178
private def extractFromDecodedNames[T](regex: Regex)(extract: List[String] => T): Option[Seq[T]] =
140179
val extracted = method.unexpandedDecodedNames

src/main/scala/ch/epfl/scala/decoder/internal/extensions.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ extension (symbol: Symbol)
2121
def isInline = symbol.isTerm && symbol.asTerm.isInline
2222
def nameStr: String = symbol.name.toString
2323

24+
def outerClass: Option[ClassSymbol] =
25+
symbol.owner match
26+
case null => None
27+
case owner: ClassSymbol => Some(owner)
28+
case owner => owner.outerClass
29+
2430
def showBasic =
2531
val span = symbol.tree.map(_.pos) match
2632
case Some(pos) if pos.isFullyDefined =>
@@ -307,3 +313,14 @@ extension (method: DecodedMethod)
307313
case _: DecodedMethod.SAMOrPartialFunctionConstructor => true
308314
case method: DecodedMethod.InlinedMethod => method.underlying.isGenerated
309315
case _ => false
316+
317+
extension (field: DecodedField)
318+
def isGenerated: Boolean =
319+
field match
320+
case field: DecodedField.ValDef => false
321+
case field: DecodedField.ModuleVal => true
322+
case field: DecodedField.LazyValOffset => true
323+
case field: DecodedField.Outer => true
324+
case field: DecodedField.SerialVersionUID => true
325+
case field: DecodedField.Capture => true
326+
case field: DecodedField.LazyValBitmap => true

0 commit comments

Comments
 (0)