|
| 1 | +package dotty.tools |
| 2 | +package dotc |
| 3 | +package typer |
| 4 | + |
| 5 | +import core._ |
| 6 | +import ast._ |
| 7 | +import ast.Trees._ |
| 8 | +import StdNames._ |
| 9 | +import Contexts._, Symbols._, Types._, SymDenotations._, Names._, NameOps._, Flags._, Decorators._ |
| 10 | +import NameKinds.DefaultGetterName |
| 11 | +import ast.desugar, ast.desugar._ |
| 12 | +import ProtoTypes._ |
| 13 | +import util.Positions._ |
| 14 | +import util.Property |
| 15 | +import collection.mutable |
| 16 | +import tpd.ListOfTreeDecorator |
| 17 | +import config.Config |
| 18 | +import config.Printers.typr |
| 19 | +import Annotations._ |
| 20 | +import Inferencing._ |
| 21 | +import transform.ValueClasses._ |
| 22 | +import transform.TypeUtils._ |
| 23 | +import transform.SymUtils._ |
| 24 | +import reporting.diagnostic.messages._ |
| 25 | + |
| 26 | +trait Deriving { this: Typer => |
| 27 | + |
| 28 | + class Deriver(cls: ClassSymbol)(implicit ctx: Context) { |
| 29 | + |
| 30 | + private var synthetics = new mutable.ListBuffer[Symbol] |
| 31 | + |
| 32 | + private def caseShape(sym: Symbol): Type = { |
| 33 | + val (constr, elems) = |
| 34 | + sym match { |
| 35 | + case caseClass: ClassSymbol => |
| 36 | + caseClass.primaryConstructor.info match { |
| 37 | + case info: PolyType => |
| 38 | + def instantiate(implicit ctx: Context) = { |
| 39 | + val poly = constrained(info, untpd.EmptyTree, alwaysAddTypeVars = true)._1 |
| 40 | + val mono @ MethodType(_) = poly.resultType |
| 41 | + val resType = mono.finalResultType |
| 42 | + resType <:< cls.appliedRef |
| 43 | + val tparams = poly.paramRefs |
| 44 | + val variances = caseClass.typeParams.map(_.paramVariance) |
| 45 | + val instanceTypes = (tparams, variances).zipped.map((tparam, variance) => |
| 46 | + ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0)) |
| 47 | + (resType.substParams(poly, instanceTypes), |
| 48 | + mono.paramInfos.map(_.substParams(poly, instanceTypes))) |
| 49 | + } |
| 50 | + instantiate(ctx.fresh.setNewTyperState().setOwner(caseClass)) |
| 51 | + case info: MethodType => |
| 52 | + (cls.typeRef, info.paramInfos) |
| 53 | + case _ => |
| 54 | + (cls.typeRef, Nil) |
| 55 | + } |
| 56 | + case _ => |
| 57 | + (sym.termRef, Nil) |
| 58 | + } |
| 59 | + val elemShape = (elems :\ (defn.UnitType: Type))(defn.PairType.appliedTo(_, _)) |
| 60 | + defn.ShapeCaseType.appliedTo(constr, elemShape) |
| 61 | + } |
| 62 | + |
| 63 | + lazy val children = cls.children.sortBy(_.pos.start) |
| 64 | + |
| 65 | + private def sealedShape: Type = { |
| 66 | + val cases = children.map(caseShape) |
| 67 | + val casesShape = (cases :\ (defn.UnitType: Type))(defn.PairType.appliedTo(_, _)) |
| 68 | + defn.ShapeCasesType.appliedTo(casesShape) |
| 69 | + } |
| 70 | + |
| 71 | + lazy val shapeWithClassParams: Type = |
| 72 | + if (cls.is(Case)) caseShape(cls) |
| 73 | + else if (cls.is(Sealed)) sealedShape |
| 74 | + else NoType |
| 75 | + |
| 76 | + lazy val shape: Type = shapeWithClassParams match { |
| 77 | + case delayed: LazyRef => HKTypeLambda.fromParams(cls.typeParams, delayed.ref) |
| 78 | + case NoType => NoType |
| 79 | + } |
| 80 | + |
| 81 | + lazy val lazyShape: Type = shapeWithClassParams match { |
| 82 | + case delayed: LazyRef => HKTypeLambda.fromParams(cls.typeParams, delayed) |
| 83 | + case NoType => NoType |
| 84 | + } |
| 85 | + |
| 86 | + class ShapeCompleter extends TypeParamsCompleter { |
| 87 | + |
| 88 | + override def completerTypeParams(sym: Symbol)(implicit ctx: Context) = cls.typeParams |
| 89 | + |
| 90 | + def completeInCreationContext(denot: SymDenotation) = { |
| 91 | + val shape0 = shapeWithClassParams |
| 92 | + val tparams = cls.typeParams |
| 93 | + val abstractedShape = |
| 94 | + if (!shape0.exists) { |
| 95 | + ctx.error(em"Cannot derive for $cls; it is neither sealed nor a case class or object", cls.pos) |
| 96 | + UnspecifiedErrorType |
| 97 | + } |
| 98 | + else if (tparams.isEmpty) |
| 99 | + shape0 |
| 100 | + else |
| 101 | + HKTypeLambda(tparams.map(_.name.withVariance(0)))( |
| 102 | + tl => tparams.map(tparam => tl.integrate(tparams, tparam.info).bounds), |
| 103 | + tl => tl.integrate(tparams, shape0)) |
| 104 | + denot.info = TypeAlias(abstractedShape) |
| 105 | + } |
| 106 | + |
| 107 | + def complete(denot: SymDenotation)(implicit ctx: Context) = |
| 108 | + completeInCreationContext(denot) |
| 109 | + } |
| 110 | + |
| 111 | + private def add(sym: Symbol): sym.type = { |
| 112 | + ctx.enter(sym) |
| 113 | + synthetics += sym |
| 114 | + sym |
| 115 | + } |
| 116 | + |
| 117 | + /** Enter type class instance with given name and info in current scope, provided |
| 118 | + * an instance woth the same name does not exist already. |
| 119 | + */ |
| 120 | + private def addDerivedInstance(clsName: Name, info: Type, reportErrors: Boolean) = { |
| 121 | + val instanceName = s"derived$$$clsName".toTermName |
| 122 | + if (ctx.denotNamed(instanceName).exists) { |
| 123 | + if (reportErrors) ctx.error(i"duplicate typeclass derivation for $clsName") |
| 124 | + } |
| 125 | + else |
| 126 | + add(ctx.newSymbol(ctx.owner, instanceName, Synthetic | Method, info, coord = cls.pos)) |
| 127 | + } |
| 128 | + |
| 129 | + /* Check derived type tree `derived` for the following well-formedness conditions: |
| 130 | + * (1) It must be a class type with a stable prefix (@see checkClassTypeWithStablePrefix) |
| 131 | + * (2) It must have exactly one type parameter |
| 132 | + * If it passes the checks, enter a typeclass instance for it in the current scope. |
| 133 | + */ |
| 134 | + private def processDerivedInstance(derived: untpd.Tree): Unit = { |
| 135 | + val uncheckedType = typedAheadType(derived, AnyTypeConstructorProto).tpe.dealiasKeepAnnots |
| 136 | + val derivedType = checkClassType(uncheckedType, derived.pos, traitReq = false, stablePrefixReq = true) |
| 137 | + val nparams = derivedType.classSymbol.typeParams.length |
| 138 | + if (nparams == 1) { |
| 139 | + val typeClass = derivedType.classSymbol |
| 140 | + val firstKindedParams = cls.typeParams.filterNot(_.info.isLambdaSub) |
| 141 | + val evidenceParamInfos = |
| 142 | + for (param <- firstKindedParams) yield derivedType.appliedTo(param.typeRef) |
| 143 | + val resultType = derivedType.appliedTo(cls.appliedRef) |
| 144 | + val instanceInfo = |
| 145 | + if (evidenceParamInfos.isEmpty) ExprType(resultType) |
| 146 | + else PolyType.fromParams(firstKindedParams, ImplicitMethodType(evidenceParamInfos, resultType)) |
| 147 | + addDerivedInstance(derivedType.typeSymbol.name, instanceInfo, reportErrors = true) |
| 148 | + } |
| 149 | + else |
| 150 | + ctx.error( |
| 151 | + i"derived class $derivedType should have one type paramater but has $nparams", |
| 152 | + derived.pos) |
| 153 | + } |
| 154 | + |
| 155 | + private def addShape(): Unit = |
| 156 | + if (!ctx.denotNamed(tpnme.Shape).exists) { |
| 157 | + val shapeSym = add(ctx.newSymbol(ctx.owner, tpnme.Shape, EmptyFlags, new ShapeCompleter)) |
| 158 | + val shapedCls = defn.ShapedClass |
| 159 | + val lazyShapedInfo = new LazyType { |
| 160 | + def complete(denot: SymDenotation)(implicit ctx: Context) = { |
| 161 | + val tparams = cls.typeParams |
| 162 | + val shapedType = shapedCls.typeRef.appliedTo( |
| 163 | + cls.appliedRef, |
| 164 | + shapeSym.typeRef.appliedTo(tparams.map(_.typeRef))) |
| 165 | + denot.info = PolyType.fromParams(tparams, shapedType).ensureMethodic |
| 166 | + } |
| 167 | + } |
| 168 | + addDerivedInstance(shapedCls.name, lazyShapedInfo, reportErrors = false) |
| 169 | + } |
| 170 | + |
| 171 | + def enterDerived(derived: List[untpd.Tree]) = { |
| 172 | + derived.foreach(processDerivedInstance(_)) |
| 173 | + addShape() |
| 174 | + } |
| 175 | + |
| 176 | + def implementedClass(instance: Symbol) = |
| 177 | + instance.info.stripPoly.finalResultType.classSymbol |
| 178 | + |
| 179 | + def typeclassInstance(sym: Symbol)(tparamRefs: List[Type])(paramRefss: List[List[tpd.Tree]]): tpd.Tree = { |
| 180 | + val tparams = tparamRefs.map(_.typeSymbol.asType) |
| 181 | + val params = if (paramRefss.isEmpty) Nil else paramRefss.head.map(_.symbol.asTerm) |
| 182 | + val typeCls = implementedClass(sym) |
| 183 | + tpd.ref(defn.Predef_undefinedR) // TODO: flesh out |
| 184 | + } |
| 185 | + |
| 186 | + def syntheticDef(sym: Symbol): tpd.Tree = |
| 187 | + if (sym.isType) tpd.TypeDef(sym.asType) |
| 188 | + else tpd.polyDefDef(sym.asTerm, typeclassInstance(sym)) |
| 189 | + |
| 190 | + def finalize(stat: tpd.TypeDef): tpd.Tree = { |
| 191 | + val templ @ Template(_, _, _, _) = stat.rhs |
| 192 | + val newDefs = synthetics.map(syntheticDef) |
| 193 | + tpd.cpy.TypeDef(stat)( |
| 194 | + rhs = tpd.cpy.Template(templ)(body = templ.body ++ newDefs)) |
| 195 | + } |
| 196 | + } |
| 197 | +} |
0 commit comments