Skip to content

Add support for Java records in patterns. #19577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
55 changes: 52 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,22 @@ object desugar {
case Select(qual, tpnme.AnyVal) => isScala(qual)
case _ => false
}

def isScala(tree: Tree): Boolean = tree match {
case Ident(nme.scala) => true
case Select(Ident(nme.ROOTPKG), nme.scala) => true
case _ => false
}

def isRecord(tree: Tree): Boolean = tree match {
case Select(Select(Select(Ident(nme.ROOTPKG), nme.java), nme.lang), tpnme.Record) => true
case _ => false
}

def namePos = cdef.sourcePos.withSpan(cdef.nameSpan)

val isJavaRecord = mods.is(JavaDefined) && parents.exists(isRecord)

val isObject = mods.is(Module)
val isCaseClass = mods.is(Case) && !isObject
val isCaseObject = mods.is(Case) && isObject
Expand Down Expand Up @@ -769,6 +777,11 @@ object desugar {

val companionMembers = defaultGetters ::: enumCases

def tupleApply(params: List[untpd.Tree]): untpd.Apply = {
val fun = Select(Ident(nme.scala), s"${StdNames.str.Tuple}$arity".toTermName)
Apply(fun, params)
}

// The companion object definitions, if a companion is needed, Nil otherwise.
// companion definitions include:
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
Expand Down Expand Up @@ -801,9 +814,8 @@ object desugar {
case vparam :: Nil =>
Apply(scalaDot(nme.Option), Select(Ident(unapplyParamName), vparam.name))
case vparams =>
val tupleApply = Select(Ident(nme.scala), s"Tuple$arity".toTermName)
val members = vparams.map(vparam => Select(Ident(unapplyParamName), vparam.name))
Apply(scalaDot(nme.Option), Apply(tupleApply, members))
val members = vparams.map(param => Select(Ident(unapplyParamName), param.name))
Apply(scalaDot(nme.Option), tupleApply(members))

val hasRepeatedParam = constrVparamss.head.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
Expand Down Expand Up @@ -832,6 +844,43 @@ object desugar {
companionDefs(anyRef, companionMembers)
else if isValueClass && !isObject then
companionDefs(anyRef, Nil)
else if (isJavaRecord) {

/** Get the canonical constructor of the Java record.
*
* Java classes have a dummy constructor; see [[JavaParsers.makeTemplate]] for
* more details
*/
def canonicalConstructor(impl: Template): DefDef = {
impl.body.collectFirst {
case ddef: DefDef if ddef.name.isConstructorName && ddef.mods.is(Synthetic) =>
ddef
}.get
}

val constr1 = canonicalConstructor(impl)
val tParams = constr1.leadingTypeParams
val vParams = asTermOnly(constr1.trailingParamss).head
val arity = vParams.length

val classTypeRef = appliedRef(classTycon)

val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)

val unapplyRHS =
if (arity == 0) Literal(Constant(true))
else Ident(unapplyParam.name)

val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()

val unapplyMeth = DefDef(
nme.unapply,
joinParams(derivedTparams, (unapplyParam :: Nil) :: Nil),
unapplyResTp,
unapplyRHS
).withMods(synthetic | Inline)
companionDefs(anyRef, unapplyMeth :: Nil)
}
else Nil

enumCompanionRef match {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,9 @@ class Definitions {
@tu lazy val RuntimeTuples_isInstanceOfEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfEmptyTuple")
@tu lazy val RuntimeTuples_isInstanceOfNonEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfNonEmptyTuple")

@tu lazy val JavaRecordReflectMirrorTypeRef: TypeRef = requiredClassRef("scala.runtime.JavaRecordMirror")
@tu lazy val JavaRecordReflectMirrorModule: Symbol = requiredModule("scala.runtime.JavaRecordMirror")

@tu lazy val TupledFunctionTypeRef: TypeRef = requiredClassRef("scala.util.TupledFunction")
def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass
def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions")
Expand Down Expand Up @@ -1636,6 +1639,7 @@ class Definitions {
def isAbstractFunctionClass(cls: Symbol): Boolean = isVarArityClass(cls, str.AbstractFunction)
def isTupleClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Tuple)
def isProductClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Product)
def isJavaRecordClass(cls: Symbol): Boolean = cls.is(JavaDefined) && cls.derivesFrom(JavaRecordClass)

def isBoxedUnitClass(cls: Symbol): Boolean =
cls.isClass && (cls.owner eq ScalaRuntimePackageClass) && cls.name == tpnme.BoxedUnit
Expand Down
37 changes: 32 additions & 5 deletions compiler/src/dotty/tools/dotc/core/NamerOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ object NamerOps:
/** The flags of an `apply` method that serves as a constructor proxy */
val ApplyProxyFlags = Synthetic | ConstructorProxy | Inline | Method

/** TODO: It would be nice if this was inline. Probably want an extra flag for `Proxy`? */
val UnApplyProxyFlags = Synthetic | Method | Inline | ConstructorProxy

/** If this is a reference to a class and the reference has a stable prefix, the reference
* otherwise NoType
*/
Expand All @@ -105,6 +108,8 @@ object NamerOps:
def complete(denot: SymDenotation)(using Context): Unit =
denot.info = constr.info

// This is weird, but possible.

/** Add constructor proxy apply methods to `scope`. Proxies are for constructors
* in `cls` and they reside in `modcls`.
*/
Expand All @@ -121,6 +126,24 @@ object NamerOps:
scope
end addConstructorApplies

def addIdentUnApply(scope: MutableScope, cls: ClassSymbol, modCls: ClassSymbol)(using ctx: Context): scope.type =
def proxy(constr: Symbol): Symbol =
val typeRef = cls.typeRef
newSymbol(
modCls, nme.unapply,
// The modifiers on unapply are essentially the same as on the constructor
UnApplyProxyFlags | (constr.flagsUNSAFE & AccessFlags),
MethodType(typeRef :: Nil, typeRef),
cls.privateWithin,
// TODO: Does this work? Or are we going to run into issues because the type it not the same?
constr.coord
)

val decl = cls.info.decls.find(_.isConstructor)
scope.enter(proxy(decl))
scope
end addIdentUnApply

/** The completer of a constructor companion for class `cls`, where
* `modul` is the companion symbol and `modcls` is its class.
*/
Expand Down Expand Up @@ -150,6 +173,7 @@ object NamerOps:
newSymbol(tsym.owner, tsym.name.toTermName,
ConstructorCompanionFlags | StableRealizable | Method, ExprType(prefix.select(proxy)), coord = tsym.coord)

// TODO: Rename to addSyntheticProxies
/** Add all necessary constructor proxy symbols for members of class `cls`. This means:
*
* - if a member is a class, or type alias, that needs a constructor companion, add one,
Expand All @@ -172,17 +196,20 @@ object NamerOps:
then
classConstructorCompanion(mbr).entered
case _ =>
// TODO: What is this?
underlyingStableClassRef(mbr.info.loBound): @unchecked match
case ref: TypeRef =>
val proxy = ref.symbol.registeredCompanion
if proxy.is(ConstructorProxy) && !memberExists(cls, mbr.name.toTermName) then
typeConstructorCompanion(mbr, ref.prefix, proxy).entered

if cls.is(Module)
&& needsConstructorProxies(cls.linkedClass)
&& !memberExists(cls, nme.apply)
then
addConstructorApplies(cls.info.decls.openForMutations, cls.linkedClass.asClass, cls)
if cls.is(Module) then
if(needsConstructorProxies(cls.linkedClass) && !memberExists(cls, nme.apply)) then
addConstructorApplies(cls.info.decls.openForMutations, cls.linkedClass.asClass, cls)

if(defn.isJavaRecordClass(cls.linkedClass) && !memberExists(cls, nme.unapply)) then
addIdentUnApply(cls.info.decls.openForMutations, cls.linkedClass.asClass, cls)

end addConstructorProxies

/** Turn `modul` into a constructor companion for class `cls` */
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ object StdNames {
final val MethodParametersATTR: N = "MethodParameters"
final val LineNumberTableATTR: N = "LineNumberTable"
final val LocalVariableTableATTR: N = "LocalVariableTable"
final val RecordATTR: N = "Record" // Introduced in JEP-395
final val RuntimeVisibleAnnotationATTR: N = "RuntimeVisibleAnnotations" // RetentionPolicy.RUNTIME
final val RuntimeInvisibleAnnotationATTR: N = "RuntimeInvisibleAnnotations" // RetentionPolicy.CLASS
final val RuntimeParamAnnotationATTR: N = "RuntimeVisibleParameterAnnotations" // RetentionPolicy.RUNTIME (annotations on parameters)
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/core/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ class SymUtils:
def canAccessCtor: Boolean =
def isAccessible(sym: Symbol): Boolean = ctx.owner.isContainedIn(sym)
def isSub(sym: Symbol): Boolean = ctx.owner.ownersIterator.exists(_.derivesFrom(sym))
val ctor = self.primaryConstructor
val ctor = if defn.isJavaRecordClass(self) then self.javaCanonicalConstructor else self.primaryConstructor
(!ctor.isOneOf(Private | Protected) || isSub(self)) // we cant access the ctor because we do not extend cls
&& (!ctor.privateWithin.exists || isAccessible(ctor.privateWithin)) // check scope is compatible


def companionMirror = self.useCompanionAsProductMirror
if (!self.is(CaseClass)) "it is not a case class"

if (!(self.is(CaseClass) || defn.isJavaRecordClass(self))) "it is not a case class or record class"
else if (self.is(Abstract)) "it is an abstract class"
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
else if self.isDerivedValueClass then "it is a value class"
Expand Down Expand Up @@ -146,7 +147,7 @@ class SymUtils:
&& (!self.is(Method) || self.is(Accessor))

def useCompanionAsProductMirror(using Context): Boolean =
self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case)
self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case) && !defn.isJavaRecordClass(self)

def useCompanionAsSumMirror(using Context): Boolean =
def companionExtendsSum(using Context): Boolean =
Expand Down Expand Up @@ -249,6 +250,16 @@ class SymUtils:
def caseAccessors(using Context): List[Symbol] =
self.info.decls.filter(_.is(CaseAccessor))

// TODO: I'm convinced that we need to introduce a flag to get the canonical constructor.
// we should also check whether the names are erased in the ctor. If not, we should
// be able to infer the components directly from the constructor.
def javaCanonicalConstructor(using Context): Symbol =
self.info.decls.filter(_.isConstructor).tail.head

// TODO: Check if `Synthetic` is stamped properly
def javaRecordComponents(using Context): List[Symbol] =
self.info.decls.filter(_.is(ParamAccessor))

def getter(using Context): Symbol =
if (self.isGetter) self else accessorNamed(self.asTerm.name.getterName)

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/SymbolLoaders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ abstract class SymbolLoader extends LazyType { self =>
if !denot.isCompleted &&
!denot.completer.isInstanceOf[SymbolLoaders.SecondCompleter] then
if denot.is(ModuleClass) && NamerOps.needsConstructorProxies(other) then
// TODO: What to do here?
NamerOps.makeConstructorCompanion(denot.sourceModule.asTerm, other.asClass)
denot.resetFlag(Touched)
else
Expand Down
34 changes: 31 additions & 3 deletions compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -895,13 +895,14 @@ class ClassfileParser(
var exceptions: List[NameOrString] = Nil
var annotations: List[Annotation] = Nil
var namedParams: Map[Int, TermName] = Map.empty

def complete(tp: Type, isVarargs: Boolean = false)(using Context): Type = {
val updatedType =
if sig == null then tp
else {
val newType = sigToType(sig, sym, isVarargs)
if (ctx.debug && ctx.verbose)
println("" + sym + "; signature = " + sig + " type = " + newType)
if (ctx.verbose)
report.debuglog("" + sym + "; signature = " + sig + " type = " + newType)
newType
}

Expand All @@ -911,7 +912,7 @@ class ClassfileParser(
if ct != null then ConstantType(ct) else updatedType
else updatedType

annotations.foreach(annot => sym.addAnnotation(annot))
annotations.foreach(sym.addAnnotation)

exceptions.foreach { ex =>
val cls = getClassSymbol(ex.name)
Expand Down Expand Up @@ -995,11 +996,38 @@ class ClassfileParser(
report.log(s"$sym in ${sym.owner} is a java 8+ default method.")
}

// https://docs.oracle.com/javase/specs/jvms/se15/preview/specs/records-jvms.html
case tpnme.RecordATTR =>
// Each member is /not/ sythetic on purpose
parseRecord()
case _ =>
}
in.bp = end
}

/**
* Parse the `Record` attribute.
*
* The `Record` attribute contains the _name_ and _descriptor_ for
* each component within the `Record` in the order of the canonical
* constructor.
*/
def parseRecord(): Unit = {
val componentsCount = in.nextChar.toInt
val components = for (i <- 0 until componentsCount) yield
val nameIndex = in.nextChar.toInt
val descriptorIndex = in.nextChar.toInt

val name = pool.getName(nameIndex)
val descriptor = pool.getName(descriptorIndex)

instanceScope.lookup(name.name).setFlag(Flags.ParamAccessor)

// TODO: Double /where/ we want these attributes to sit.
skipAttributes()
(name, name)
}

/**
* Parse the "Exceptions" attribute which denotes the exceptions
* thrown by a method.
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,11 @@ object JavaParsers {
fieldsByName -= name
end for

// The `Synthetic` flag here is only used by the `Namer` to evict overriden symbols during mixed-compilation
val accessors =
(for (name, (tpt, annots)) <- fieldsByName yield
DefDef(name, List(Nil), tpt, unimplementedExpr)
.withMods(Modifiers(Flags.JavaDefined | Flags.Method | Flags.Synthetic))
.withMods(Modifiers(Flags.JavaDefined | Flags.Method | Flags.ParamAccessor | Flags.Synthetic))
).toList

// generate the canonical constructor
Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import core.*
import Flags.*
import Contexts.*
import Symbols.*
import StdNames.*
import Decorators.*

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees.*
Expand All @@ -14,6 +16,8 @@ import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
import dotty.tools.dotc.staging.StagingLevel

import config.Printers.inlining

import scala.collection.mutable.ListBuffer

/** Inlines all calls to inline methods that are not in an inline method or a quote */
Expand Down Expand Up @@ -55,8 +59,9 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
}

def newTransformer(using Context): Transformer = new Transformer {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
override def transform(tree: tpd.Tree)(using Context): tpd.Tree = {
new InliningTreeMap().transform(tree)
}
}

private class InliningTreeMap extends TreeMapWithImplicits {
Expand Down Expand Up @@ -88,6 +93,11 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {

flatTree(trees2)
else super.transform(tree)
case Apply(fun, args) if fun.symbol.name == nme.unapply && fun.symbol.is(ConstructorProxy) =>
Copy link
Contributor Author

@yilinwei yilinwei Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this transformation happen in the Typer just after we've created the unapply object instead - this would mean the pickler wouldn't see the proxy; would that be better?

// TODO: Add for using it as a value
val tree1 = args.head
inlining.println(i"reducing unapply proxy: $tree -> $tree1")
tree1
case _: Typed | _: Block =>
super.transform(tree)
case _ if Inlines.needsInlining(tree) =>
Expand Down
14 changes: 12 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,15 @@ object PatternMatcher {
.select(defn.RuntimeTuples_apply)
.appliedTo(receiver, Literal(Constant(i)))

def resultTypeSym = unapp.symbol.info.resultType.typeSymbol

// TODO: Check Scala -> Java, erased?
def isJavaRecordUnapply(sym: Symbol) = defn.isJavaRecordClass(resultTypeSym)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
def recordSel(sym: Symbol) = tupleSel(sym).appliedToTermArgs(Nil)

// TODO: Move this to the correct location
if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
val isGenericTuple = defn.isTupleClass(caseClass) &&
!defn.isTupleNType(tree.tpe match { case tp: OrType => tp.join case tp => tp }) // widen even hard unions, to see if it's a union of tuples
val components = if isGenericTuple then caseAccessors.indices.toList.map(tupleApp(_, ref(scrutinee))) else caseAccessors.map(tupleSel)
Expand All @@ -369,7 +376,10 @@ object PatternMatcher {
else if unappResult.info <:< defn.NonEmptyTupleTypeRef then
val components = (0 until foldApplyTupleType(unappResult.denot.info).length).toList.map(tupleApp(_, ref(unappResult)))
matchArgsPlan(components, args, onSuccess)
else {
else if (isJavaRecordUnapply(unapp.symbol.owner)) {
val components = resultTypeSym.javaRecordComponents.map(recordSel)
matchArgsPlan(components, args, onSuccess)
} else {
assert(isGetMatch(unapp.tpe))
val argsPlan = {
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
Expand Down
Loading