Skip to content

Commit c438299

Browse files
committed
Partial Implementation of Deriving functionality
1 parent 5885e59 commit c438299

File tree

8 files changed

+634
-39
lines changed

8 files changed

+634
-39
lines changed

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ class SymUtils(val self: Symbol) extends AnyVal {
135135

136136
/** If this is a sealed class, its known children */
137137
def children(implicit ctx: Context): List[Symbol] = {
138-
if (self.isType) self.setFlag(ChildrenQueried)
138+
if (self.isType)
139+
self.setFlag(ChildrenQueried)
139140
self.annotations.collect {
140141
case Annotation.Child(child) => child
141142
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ trait Checking {
916916
* @param cdef the enum companion object class
917917
* @param enumCtx the context immediately enclosing the corresponding enum
918918
*/
919-
private def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(implicit ctx: Context): Unit = {
919+
def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(implicit ctx: Context): Unit = {
920920

921921
def checkCaseOrDefault(stat: Tree, caseCtx: Context) = {
922922

@@ -971,24 +971,13 @@ trait Checking {
971971
case _ =>
972972
}
973973
}
974-
975-
/** Check all enum cases in all enum companions in `stats` for legal accesses.
976-
* @param enumContexts a map from`enum` symbols to the contexts enclosing their definitions
977-
*/
978-
def checkEnumCompanions(stats: List[Tree], enumContexts: collection.Map[Symbol, Context])(implicit ctx: Context): List[Tree] = {
979-
for (stat @ TypeDef(_, _) <- stats)
980-
if (stat.symbol.is(Module))
981-
for (enumContext <- enumContexts.get(stat.symbol.linkedClass))
982-
checkEnumCaseRefsLegal(stat, enumContext)
983-
stats
984-
}
985974
}
986975

987976
trait ReChecking extends Checking {
988977
import tpd._
989978
override def checkEnum(cdef: untpd.TypeDef, cls: Symbol, parent: Symbol)(implicit ctx: Context): Unit = ()
990979
override def checkRefsLegal(tree: tpd.Tree, badOwner: Symbol, allowed: (Name, Symbol) => Boolean, where: String)(implicit ctx: Context): Unit = ()
991-
override def checkEnumCompanions(stats: List[Tree], enumContexts: collection.Map[Symbol, Context])(implicit ctx: Context): List[Tree] = stats
980+
override def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(implicit ctx: Context): Unit = ()
992981
}
993982

994983
trait NoChecking extends ReChecking {
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
package dotty.tools
2+
package dotc
3+
package typer
4+
5+
import core._
6+
import ast._
7+
import ast.Trees._
8+
import StdNames._
9+
import Contexts._, Symbols._, Types._, SymDenotations._, Names._, NameOps._, Flags._, Decorators._
10+
import NameKinds.DefaultGetterName
11+
import ast.desugar, ast.desugar._
12+
import ProtoTypes._
13+
import util.Positions._
14+
import util.Property
15+
import collection.mutable
16+
import tpd.ListOfTreeDecorator
17+
import config.Config
18+
import config.Printers.typr
19+
import Annotations._
20+
import Inferencing._
21+
import transform.ValueClasses._
22+
import transform.TypeUtils._
23+
import transform.SymUtils._
24+
import reporting.diagnostic.messages._
25+
26+
trait Deriving { this: Typer =>
27+
28+
class Deriver(cls: ClassSymbol)(implicit ctx: Context) {
29+
30+
private var synthetics = new mutable.ListBuffer[Symbol]
31+
32+
private def caseShape(sym: Symbol): Type = {
33+
val (constr, elems) =
34+
sym match {
35+
case caseClass: ClassSymbol =>
36+
caseClass.primaryConstructor.info match {
37+
case info: PolyType =>
38+
def instantiate(implicit ctx: Context) = {
39+
val poly = constrained(info, untpd.EmptyTree, alwaysAddTypeVars = true)._1
40+
val mono @ MethodType(_) = poly.resultType
41+
val resType = mono.finalResultType
42+
resType <:< cls.appliedRef
43+
val tparams = poly.paramRefs
44+
val variances = caseClass.typeParams.map(_.paramVariance)
45+
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
46+
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
47+
(resType.substParams(poly, instanceTypes),
48+
mono.paramInfos.map(_.substParams(poly, instanceTypes)))
49+
}
50+
instantiate(ctx.fresh.setNewTyperState().setOwner(caseClass))
51+
case info: MethodType =>
52+
(cls.typeRef, info.paramInfos)
53+
case _ =>
54+
(cls.typeRef, Nil)
55+
}
56+
case _ =>
57+
(sym.termRef, Nil)
58+
}
59+
val elemShape = (elems :\ (defn.UnitType: Type))(defn.PairType.appliedTo(_, _))
60+
defn.ShapeCaseType.appliedTo(constr, elemShape)
61+
}
62+
63+
lazy val children = cls.children.sortBy(_.pos.start)
64+
65+
private def sealedShape: Type = {
66+
val cases = children.map(caseShape)
67+
val casesShape = (cases :\ (defn.UnitType: Type))(defn.PairType.appliedTo(_, _))
68+
defn.ShapeCasesType.appliedTo(casesShape)
69+
}
70+
71+
lazy val shapeWithClassParams: Type =
72+
if (cls.is(Case)) caseShape(cls)
73+
else if (cls.is(Sealed)) sealedShape
74+
else NoType
75+
76+
lazy val shape: Type = shapeWithClassParams match {
77+
case delayed: LazyRef => HKTypeLambda.fromParams(cls.typeParams, delayed.ref)
78+
case NoType => NoType
79+
}
80+
81+
lazy val lazyShape: Type = shapeWithClassParams match {
82+
case delayed: LazyRef => HKTypeLambda.fromParams(cls.typeParams, delayed)
83+
case NoType => NoType
84+
}
85+
86+
class ShapeCompleter extends TypeParamsCompleter {
87+
88+
override def completerTypeParams(sym: Symbol)(implicit ctx: Context) = cls.typeParams
89+
90+
def completeInCreationContext(denot: SymDenotation) = {
91+
val shape0 = shapeWithClassParams
92+
val tparams = cls.typeParams
93+
val abstractedShape =
94+
if (!shape0.exists) {
95+
ctx.error(em"Cannot derive for $cls; it is neither sealed nor a case class or object", cls.pos)
96+
UnspecifiedErrorType
97+
}
98+
else if (tparams.isEmpty)
99+
shape0
100+
else
101+
HKTypeLambda(tparams.map(_.name.withVariance(0)))(
102+
tl => tparams.map(tparam => tl.integrate(tparams, tparam.info).bounds),
103+
tl => tl.integrate(tparams, shape0))
104+
denot.info = TypeAlias(abstractedShape)
105+
}
106+
107+
def complete(denot: SymDenotation)(implicit ctx: Context) =
108+
completeInCreationContext(denot)
109+
}
110+
111+
private def add(sym: Symbol): sym.type = {
112+
ctx.enter(sym)
113+
synthetics += sym
114+
sym
115+
}
116+
117+
/** Enter type class instance with given name and info in current scope, provided
118+
* an instance woth the same name does not exist already.
119+
*/
120+
private def addDerivedInstance(clsName: Name, info: Type, reportErrors: Boolean) = {
121+
val instanceName = s"derived$$$clsName".toTermName
122+
if (ctx.denotNamed(instanceName).exists) {
123+
if (reportErrors) ctx.error(i"duplicate typeclass derivation for $clsName")
124+
}
125+
else
126+
add(ctx.newSymbol(ctx.owner, instanceName, Synthetic | Method, info, coord = cls.pos))
127+
}
128+
129+
/* Check derived type tree `derived` for the following well-formedness conditions:
130+
* (1) It must be a class type with a stable prefix (@see checkClassTypeWithStablePrefix)
131+
* (2) It must have exactly one type parameter
132+
* If it passes the checks, enter a typeclass instance for it in the current scope.
133+
*/
134+
private def processDerivedInstance(derived: untpd.Tree): Unit = {
135+
val uncheckedType = typedAheadType(derived, AnyTypeConstructorProto).tpe.dealiasKeepAnnots
136+
val derivedType = checkClassType(uncheckedType, derived.pos, traitReq = false, stablePrefixReq = true)
137+
val nparams = derivedType.classSymbol.typeParams.length
138+
if (nparams == 1) {
139+
val typeClass = derivedType.classSymbol
140+
val firstKindedParams = cls.typeParams.filterNot(_.info.isLambdaSub)
141+
val evidenceParamInfos =
142+
for (param <- firstKindedParams) yield derivedType.appliedTo(param.typeRef)
143+
val resultType = derivedType.appliedTo(cls.appliedRef)
144+
val instanceInfo =
145+
if (evidenceParamInfos.isEmpty) ExprType(resultType)
146+
else PolyType.fromParams(firstKindedParams, ImplicitMethodType(evidenceParamInfos, resultType))
147+
addDerivedInstance(derivedType.typeSymbol.name, instanceInfo, reportErrors = true)
148+
}
149+
else
150+
ctx.error(
151+
i"derived class $derivedType should have one type paramater but has $nparams",
152+
derived.pos)
153+
}
154+
155+
private def addShape(): Unit =
156+
if (!ctx.denotNamed(tpnme.Shape).exists) {
157+
val shapeSym = add(ctx.newSymbol(ctx.owner, tpnme.Shape, EmptyFlags, new ShapeCompleter))
158+
val shapedCls = defn.ShapedClass
159+
val lazyShapedInfo = new LazyType {
160+
def complete(denot: SymDenotation)(implicit ctx: Context) = {
161+
val tparams = cls.typeParams
162+
val shapedType = shapedCls.typeRef.appliedTo(
163+
cls.appliedRef,
164+
shapeSym.typeRef.appliedTo(tparams.map(_.typeRef)))
165+
denot.info = PolyType.fromParams(tparams, shapedType).ensureMethodic
166+
}
167+
}
168+
addDerivedInstance(shapedCls.name, lazyShapedInfo, reportErrors = false)
169+
}
170+
171+
def enterDerived(derived: List[untpd.Tree]) = {
172+
derived.foreach(processDerivedInstance(_))
173+
addShape()
174+
}
175+
176+
def implementedClass(instance: Symbol) =
177+
instance.info.stripPoly.finalResultType.classSymbol
178+
179+
def typeclassInstance(sym: Symbol)(tparamRefs: List[Type])(paramRefss: List[List[tpd.Tree]]): tpd.Tree = {
180+
val tparams = tparamRefs.map(_.typeSymbol.asType)
181+
val params = if (paramRefss.isEmpty) Nil else paramRefss.head.map(_.symbol.asTerm)
182+
val typeCls = implementedClass(sym)
183+
tpd.ref(defn.Predef_undefinedR) // TODO: flesh out
184+
}
185+
186+
def syntheticDef(sym: Symbol): tpd.Tree =
187+
if (sym.isType) tpd.TypeDef(sym.asType)
188+
else tpd.polyDefDef(sym.asTerm, typeclassInstance(sym))
189+
190+
def finalize(stat: tpd.TypeDef): tpd.Tree = {
191+
val templ @ Template(_, _, _, _) = stat.rhs
192+
val newDefs = synthetics.map(syntheticDef)
193+
tpd.cpy.TypeDef(stat)(
194+
rhs = tpd.cpy.Template(templ)(body = templ.body ++ newDefs))
195+
}
196+
}
197+
}

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ class Namer { typer: Typer =>
194194
val TypedAhead: Property.Key[tpd.Tree] = new Property.Key
195195
val ExpandedTree: Property.Key[untpd.Tree] = new Property.Key
196196
val SymOfTree: Property.Key[Symbol] = new Property.Key
197+
val DerivingCompanion: Property.Key[Unit] = new Property.Key
198+
val Deriver: Property.Key[typer.Deriver] = new Property.Key
197199

198200
/** A partial map from unexpanded member and pattern defs and to their expansions.
199201
* Populated during enterSyms, emptied during typer.
@@ -531,7 +533,7 @@ class Namer { typer: Typer =>
531533
}
532534

533535
/** Merge the module class `modCls` in the expanded tree of `mdef` with the
534-
* body and derived clause of the syntehtic module class `fromCls`.
536+
* body and derived clause of the synthetic module class `fromCls`.
535537
*/
536538
def mergeModuleClass(mdef: Tree, modCls: TypeDef, fromCls: TypeDef): TypeDef = {
537539
var res: TypeDef = null
@@ -542,8 +544,13 @@ class Namer { typer: Typer =>
542544
val modTempl = modCls.rhs.asInstanceOf[Template]
543545
res = cpy.TypeDef(modCls)(
544546
rhs = cpy.Template(modTempl)(
545-
derived = fromTempl.derived ++ modTempl.derived,
547+
derived = if (fromTempl.derived.nonEmpty) fromTempl.derived else modTempl.derived,
546548
body = fromTempl.body ++ modTempl.body))
549+
if (fromTempl.derived.nonEmpty) {
550+
if (modTempl.derived.nonEmpty)
551+
ctx.error(em"a class and its companion cannot both have `derives' clauses", mdef.pos)
552+
res.putAttachment(DerivingCompanion, ())
553+
}
547554
res
548555
}
549556
else tree
@@ -810,7 +817,8 @@ class Namer { typer: Typer =>
810817
cls.addAnnotation(Annotation.Child(child))
811818
else
812819
ctx.error(em"""children of ${cls} were already queried before $sym was discovered.
813-
|As a remedy, you could move $sym on the same nesting level as $cls.""")
820+
|As a remedy, you could move $sym on the same nesting level as $cls.""",
821+
child.pos)
814822
}
815823
}
816824

@@ -946,23 +954,6 @@ class Namer { typer: Typer =>
946954
}
947955
}
948956

949-
/* Check derived type tree `derived` for the following well-formedness conditions:
950-
* (1) It must be a class type with a stable prefix (@see checkClassTypeWithStablePrefix)
951-
* (2) It must have exactly one type parameter
952-
* If it passes the checks, enter a typeclass instance for it in the class scope.
953-
*/
954-
def addDerivedInstance(derived: untpd.Tree): Unit = {
955-
val uncheckedType = typedAheadType(derived, AnyTypeConstructorProto).tpe.dealiasKeepAnnots
956-
val derivedType = checkClassType(uncheckedType, derived.pos, traitReq = false, stablePrefixReq = true)
957-
val nparams = derivedType.typeSymbol.typeParams.length
958-
if (nparams == 1)
959-
println(i"todo: add derived instance $derived")
960-
else
961-
ctx.error(
962-
i"derived class $derivedType should have one type paramater but has $nparams",
963-
derived.pos)
964-
}
965-
966957
addAnnotations(denot.symbol)
967958

968959
val selfInfo: TypeOrSymbol =
@@ -980,20 +971,29 @@ class Namer { typer: Typer =>
980971
val tempInfo = new TempClassInfo(cls.owner.thisType, cls, decls, selfInfo)
981972
denot.info = tempInfo
982973

974+
val localCtx = ctx.inClassContext(selfInfo)
975+
983976
// Ensure constructor is completed so that any parameter accessors
984977
// which have type trees deriving from its parameters can be
985978
// completed in turn. Note that parent types access such parameter
986979
// accessors, that's why the constructor needs to be completed before
987980
// the parent types are elaborated.
988981
index(constr)
989-
index(rest)(ctx.inClassContext(selfInfo))
982+
index(rest)(localCtx)
990983
symbolOfTree(constr).ensureCompleted()
991984

992985
val parentTypes = defn.adjustForTuple(cls, cls.typeParams,
993986
ensureFirstIsClass(parents.map(checkedParentType(_)), cls.pos))
994987
typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %")
995988

996-
impl.derived.foreach(addDerivedInstance)
989+
if (impl.derived.nonEmpty) {
990+
val derivingClass =
991+
if (original.removeAttachment(DerivingCompanion).isDefined) cls.companionClass.asClass
992+
else cls
993+
val deriver = new Deriver(derivingClass)(localCtx)
994+
deriver.enterDerived(impl.derived)
995+
original.putAttachment(Deriver, deriver)
996+
}
997997

998998
val finalSelfInfo: TypeOrSymbol =
999999
if (cls.isOpaqueCompanion) {

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ object ProtoTypes {
446446
* If the constraint contains already some of these parameters in its domain,
447447
* make a copy of the type lambda and add the copy's type parameters instead.
448448
* Return either the original type lambda, or the copy, if one was made.
449-
* Also, if `owningTree` is non-empty ot `alwaysAddTypeVars` is true, add a type variable
449+
* Also, if `owningTree` is non-empty or `alwaysAddTypeVars` is true, add a type variable
450450
* for each parameter.
451451
* @return The added type lambda, and the list of created type variables.
452452
*/

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ class Typer extends Namer
8282
with Implicits
8383
with Inferencing
8484
with Dynamic
85-
with Checking {
85+
with Checking
86+
with Deriving {
8687

8788
import Typer._
8889
import tpd.{cpy => _, _}
@@ -1650,6 +1651,9 @@ class Typer extends Namer
16501651
if (ctx.mode.is(Mode.Interactive) && ctx.settings.YretainTrees.value)
16511652
cls.rootTreeOrProvider = cdef1
16521653

1654+
for (deriver <- cdef.removeAttachment(Deriver))
1655+
cdef1.putAttachment(Deriver, deriver)
1656+
16531657
cdef1
16541658

16551659
// todo later: check that
@@ -2055,7 +2059,18 @@ class Typer extends Namer
20552059
val exprOwnerOpt = if (exprOwner == ctx.owner) None else Some(exprOwner)
20562060
ctx.withProperty(ExprOwner, exprOwnerOpt)
20572061
}
2058-
checkEnumCompanions(traverse(stats)(localCtx), enumContexts)
2062+
def finalize(stat: Tree)(implicit ctx: Context): Tree = stat match {
2063+
case stat: TypeDef if stat.symbol.is(Module) =>
2064+
for (enumContext <- enumContexts.get(stat.symbol.linkedClass))
2065+
checkEnumCaseRefsLegal(stat, enumContext)
2066+
stat.removeAttachment(Deriver) match {
2067+
case Some(deriver) => deriver.finalize(stat)
2068+
case None => stat
2069+
}
2070+
case _ =>
2071+
stat
2072+
}
2073+
traverse(stats)(localCtx).mapConserve(finalize)
20592074
}
20602075

20612076
/** Given an inline method `mdef`, the method rewritten so that its body

0 commit comments

Comments
 (0)