Skip to content

Commit b6cb517

Browse files
committed
CreateCompanionObjects transformer
A transformer that provides a convenient way to create companion objects.
1 parent 3eabe4a commit b6cb517

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer}
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Contexts.Context
6+
import scala.collection.mutable.ListBuffer
7+
import dotty.tools.dotc.core.{Scopes, Flags}
8+
import dotty.tools.dotc.core.Symbols.NoSymbol
9+
import scala.annotation.tailrec
10+
import dotty.tools.dotc.core._
11+
import Symbols._
12+
import scala.Some
13+
import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer}
14+
import dotty.tools.dotc.ast.tpd
15+
import dotty.tools.dotc.core.Contexts.Context
16+
import scala.collection.mutable
17+
import dotty.tools.dotc.core.Names.Name
18+
import NameOps._
19+
20+
/** A transformer that provides a convenient way to create companion objects
21+
*/
22+
abstract class CreateCompanionObjects(group: TreeTransformer, idx: Int) extends TreeTransform(group, idx) {
23+
24+
import tpd._
25+
26+
/** Given class definition should return true if companion object creation should be enforced
27+
*/
28+
def predicate(cls: TypeDef): Boolean
29+
30+
override def transformStats(trees: List[Tree])(implicit ctx: Context, info: TransformerInfo): List[tpd.Tree] = {
31+
@tailrec
32+
def transformStats0(trees: List[Tree], acc: ListBuffer[Tree]): List[Tree] = {
33+
trees match {
34+
case Nil => acc.toList
35+
case (claz: TypeDef) :: stats if claz.symbol.isClass && !(claz.symbol is Flags.Module) => {
36+
val moduleExists = !(claz.symbol.companionModule eq NoSymbol)
37+
if (moduleExists || !predicate(claz)) transformStats0(stats, acc += claz)
38+
else {
39+
val moduleSymbol = ctx.newCompleteModuleSymbol(claz.symbol.owner, claz.name.toTermName, Flags.Synthetic, Flags.Synthetic, List(defn.ObjectClass.typeRef), Scopes.newScope)
40+
if (moduleSymbol.owner.isClass) moduleSymbol.entered
41+
val companion = tpd.ModuleDef(moduleSymbol, List(EmptyTree))
42+
acc += claz
43+
acc += companion
44+
transformStats0(stats, acc)
45+
}
46+
}
47+
case stat :: stats => transformStats0(stats, acc += stat)
48+
}
49+
}
50+
51+
transformStats0(trees, ListBuffer())
52+
}
53+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package test.transform
2+
3+
4+
import org.junit.{Assert, Test}
5+
import test.DottyTest
6+
import dotty.tools.dotc.core._
7+
import dotty.tools.dotc.ast.{tpd, Trees}
8+
import Contexts._
9+
import Flags._
10+
import Denotations._
11+
import NameOps._
12+
import Symbols._
13+
import Types._
14+
import Decorators._
15+
import Trees._
16+
import dotty.tools.dotc.transform.TreeTransforms.{TreeTransform, TreeTransformer}
17+
import dotty.tools.dotc.transform.PostTyperTransformers.PostTyperTransformer
18+
import dotty.tools.dotc.transform.CreateCompanionObjects
19+
20+
21+
class CreateCompanionObjectsTest extends DottyTest {
22+
23+
import tpd._
24+
25+
@Test
26+
def shouldCreateNonExistingObjectsInPackage = checkCompile("frontend", "class A{} ") {
27+
(tree, context) =>
28+
implicit val ctx = context
29+
30+
val transformer = new PostTyperTransformer {
31+
override def transformations = Array(new CreateCompanionObjects(_, _) {
32+
override def predicate(cts: TypeDef): Boolean = true
33+
})
34+
35+
override def name: String = "test"
36+
}
37+
val transformed = transformer.transform(tree).toString
38+
println(transformed)
39+
val classPattern = "TypeDef(Modifiers(,,List()),A,"
40+
val classPos = transformed.indexOf(classPattern)
41+
val moduleClassPattern = "TypeDef(Modifiers(final module <synthetic>,,List()),A$"
42+
val modulePos = transformed.indexOf(moduleClassPattern)
43+
44+
Assert.assertTrue("should create non-existing objects in package",
45+
classPos < modulePos
46+
)
47+
}
48+
49+
@Test
50+
def shouldCreateNonExistingObjectsInBlock = checkCompile("frontend", "class D {def p = {class A{}; 1}} ") {
51+
(tree, context) =>
52+
implicit val ctx = context
53+
val transformer = new PostTyperTransformer {
54+
override def transformations = Array(new CreateCompanionObjects(_, _) {
55+
override def predicate(cts: TypeDef): Boolean = true
56+
})
57+
58+
override def name: String = "test"
59+
}
60+
val transformed = transformer.transform(tree).toString
61+
val classPattern = "TypeDef(Modifiers(,,List()),A,"
62+
val classPos = transformed.indexOf(classPattern)
63+
val moduleClassPattern = "TypeDef(Modifiers(final module <synthetic>,,List()),A$"
64+
val modulePos = transformed.indexOf(moduleClassPattern)
65+
66+
Assert.assertTrue("should create non-existing objects in block",
67+
classPos < modulePos
68+
)
69+
}
70+
71+
@Test
72+
def shouldCreateNonExistingObjectsInTemplate = checkCompile("frontend", "class D {class A{}; } ") {
73+
(tree, context) =>
74+
implicit val ctx = context
75+
val transformer = new PostTyperTransformer {
76+
override def transformations = Array(new CreateCompanionObjects(_, _) {
77+
override def predicate(cts: TypeDef): Boolean = true
78+
})
79+
80+
override def name: String = "test"
81+
}
82+
val transformed = transformer.transform(tree).toString
83+
val classPattern = "TypeDef(Modifiers(,,List()),A,"
84+
val classPos = transformed.indexOf(classPattern)
85+
val moduleClassPattern = "TypeDef(Modifiers(final module <synthetic>,,List()),A$"
86+
val modulePos = transformed.indexOf(moduleClassPattern)
87+
88+
Assert.assertTrue("should create non-existing objects in template",
89+
classPos < modulePos
90+
)
91+
}
92+
93+
@Test
94+
def shouldCreateOnlyIfAskedFor = checkCompile("frontend", "class DONT {class CREATE{}; } ") {
95+
(tree, context) =>
96+
implicit val ctx = context
97+
val transformer = new PostTyperTransformer {
98+
override def transformations = Array(new CreateCompanionObjects(_, _) {
99+
override def predicate(cts: TypeDef): Boolean = cts.name.toString.contains("CREATE")
100+
})
101+
102+
override def name: String = "test"
103+
}
104+
val transformed = transformer.transform(tree).toString
105+
val classPattern = "TypeDef(Modifiers(,,List()),A,"
106+
val classPos = transformed.indexOf(classPattern)
107+
val moduleClassPattern = "TypeDef(Modifiers(final module <synthetic>,,List()),CREATE$"
108+
val modulePos = transformed.indexOf(moduleClassPattern)
109+
110+
val notCreatedModulePattern = "TypeDef(Modifiers(final module <synthetic>,,List()),DONT"
111+
val notCreatedPos = transformed.indexOf(notCreatedModulePattern)
112+
113+
Assert.assertTrue("should create non-existing objects in template",
114+
classPos < modulePos && (notCreatedPos < 0)
115+
)
116+
}
117+
}

0 commit comments

Comments
 (0)