Skip to content

Commit 58b8aaa

Browse files
authored
Merge pull request #5843 from dotty-staging/derive-multiversal
Base multiversal equality on typeclass derivation
2 parents f04b2ea + 4bc1514 commit 58b8aaa

36 files changed

+450
-351
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+2-41
Original file line numberDiff line numberDiff line change
@@ -536,45 +536,6 @@ object desugar {
536536
if (isEnum)
537537
parents1 = parents1 :+ ref(defn.EnumType)
538538

539-
// The Eq instance for an Enum class. For an enum class
540-
//
541-
// enum class C[T1, ..., Tn]
542-
//
543-
// we generate:
544-
//
545-
// implicit def eqInstance[T1$1, ..., Tn$1, T1$2, ..., Tn$2](implicit
546-
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
547-
// : Eq[C[T1$, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
548-
//
549-
// Higher-kinded type arguments `Ti` are omitted as evidence parameters.
550-
//
551-
// FIXME: This is too simplistic. Instead of just generating evidence arguments
552-
// for every first-kinded type parameter, we should look instead at the
553-
// actual types occurring in cases and derive parameters from these. E.g. in
554-
//
555-
// enum HK[F[_]] {
556-
// case C1(x: F[Int]) extends HK[F[Int]]
557-
// case C2(y: F[String]) extends HL[F[Int]]
558-
//
559-
// we would need evidence parameters for `F[Int]` and `F[String]`
560-
// We should generate Eq instances with the techniques
561-
// of typeclass derivation once that is available.
562-
def eqInstance = {
563-
val leftParams = constrTparams.map(derivedTypeParam(_, "$1"))
564-
val rightParams = constrTparams.map(derivedTypeParam(_, "$2"))
565-
val subInstances =
566-
for ((param1, param2) <- leftParams `zip` rightParams if !isHK(param1))
567-
yield appliedRef(ref(defn.EqType), List(param1, param2), widenHK = true)
568-
DefDef(
569-
name = nme.eqInstance,
570-
tparams = leftParams ++ rightParams,
571-
vparamss = if (subInstances.isEmpty) Nil else List(makeImplicitParameters(subInstances)),
572-
tpt = appliedTypeTree(ref(defn.EqType),
573-
appliedRef(classTycon, leftParams) :: appliedRef(classTycon, rightParams) :: Nil),
574-
rhs = ref(defn.EqModule.termRef)).withFlags(Synthetic | Implicit)
575-
}
576-
def eqInstances = if (isEnum) eqInstance :: Nil else Nil
577-
578539
// derived type classes of non-module classes go to their companions
579540
val (clsDerived, companionDerived) =
580541
if (mods.is(Module)) (impl.derived, Nil) else (Nil, impl.derived)
@@ -593,7 +554,7 @@ object desugar {
593554
mdefs
594555
}
595556

596-
val companionMembers = defaultGetters ::: eqInstances ::: enumCases
557+
val companionMembers = defaultGetters ::: enumCases
597558

598559
// The companion object definitions, if a companion is needed, Nil otherwise.
599560
// companion definitions include:
@@ -643,7 +604,7 @@ object desugar {
643604
}
644605
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
645606
}
646-
else if (companionMembers.nonEmpty || companionDerived.nonEmpty)
607+
else if (companionMembers.nonEmpty || companionDerived.nonEmpty || isEnum)
647608
companionDefs(anyRef, companionMembers)
648609
else if (isValueClass) {
649610
impl.constr.vparamss match {

compiler/src/dotty/tools/dotc/core/Definitions.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -734,11 +734,11 @@ class Definitions {
734734
lazy val TastyReflectionModule: TermSymbol = ctx.requiredModule("scala.tasty.Reflection")
735735
lazy val TastyReflection_macroContext: TermSymbol = TastyReflectionModule.requiredMethod("macroContext")
736736

737-
lazy val EqType: TypeRef = ctx.requiredClassRef("scala.Eq")
738-
def EqClass(implicit ctx: Context): ClassSymbol = EqType.symbol.asClass
739-
def EqModule(implicit ctx: Context): Symbol = EqClass.companionModule
737+
lazy val EqlType: TypeRef = ctx.requiredClassRef("scala.Eql")
738+
def EqlClass(implicit ctx: Context): ClassSymbol = EqlType.symbol.asClass
739+
def EqlModule(implicit ctx: Context): Symbol = EqlClass.companionModule
740740

741-
def Eq_eqAny(implicit ctx: Context): TermSymbol = EqModule.requiredMethod(nme.eqAny)
741+
def Eql_eqlAny(implicit ctx: Context): TermSymbol = EqlModule.requiredMethod(nme.eqlAny)
742742

743743
lazy val NotType: TypeRef = ctx.requiredClassRef("scala.implicits.Not")
744744
def NotClass(implicit ctx: Context): ClassSymbol = NotType.symbol.asClass

compiler/src/dotty/tools/dotc/core/Denotations.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ import collection.mutable.ListBuffer
6767
*/
6868
object Denotations {
6969

70-
implicit def eqDenotation: Eq[Denotation, Denotation] = Eq
70+
implicit def eqDenotation: Eql[Denotation, Denotation] = Eql.derived
7171

7272
/** A PreDenotation represents a group of single denotations or a single multi-denotation
7373
* It is used as an optimization to avoid forming MultiDenotations too eagerly.

compiler/src/dotty/tools/dotc/core/Mode.scala

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ object Mode {
5252
/** Allow GADTFlexType labelled types to have their bounds adjusted */
5353
val GADTflexible: Mode = newMode(8, "GADTflexible")
5454

55+
/** Assume -language:strictEquality */
56+
val StrictEquality: Mode = newMode(9, "StrictEquality")
57+
5558
/** We are currently printing something: avoid to produce more logs about
5659
* the printing
5760
*/

compiler/src/dotty/tools/dotc/core/Names.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ object Names {
2525
def toTermName: TermName
2626
}
2727

28-
implicit def eqName: Eq[Name, Name] = Eq
28+
implicit def eqName: Eql[Name, Name] = Eql.derived
2929

3030
/** A common superclass of Name and Symbol. After bootstrap, this should be
3131
* just the type alias Name | Symbol

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ object StdNames {
418418
val equals_ : N = "equals"
419419
val error: N = "error"
420420
val eval: N = "eval"
421-
val eqAny: N = "eqAny"
421+
val eqlAny: N = "eqlAny"
422422
val ex: N = "ex"
423423
val experimental: N = "experimental"
424424
val f: N = "f"

compiler/src/dotty/tools/dotc/core/Symbols.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ trait Symbols { this: Context =>
406406

407407
object Symbols {
408408

409-
implicit def eqSymbol: Eq[Symbol, Symbol] = Eq
409+
implicit def eqSymbol: Eql[Symbol, Symbol] = Eql.derived
410410

411411
/** Tree attachment containing the identifiers in a tree as a sorted array */
412412
val Ids: Property.Key[Array[String]] = new Property.Key

compiler/src/dotty/tools/dotc/core/Types.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ object Types {
4040

4141
@sharable private[this] var nextId = 0
4242

43-
implicit def eqType: Eq[Type, Type] = Eq
43+
implicit def eqType: Eql[Type, Type] = Eql.derived
4444

4545
/** Main class representing types.
4646
*

compiler/src/dotty/tools/dotc/typer/Deriving.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ trait Deriving { this: Typer =>
193193
if (nparams == 0) Nil
194194
else if (nparams == 1) tparam :: Nil
195195
else typeClass.typeParams.map(tcparam =>
196-
tparam.copy(name = s"${tparam.name}_${tcparam.name}".toTypeName)
196+
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
197197
.asInstanceOf[TypeSymbol])
198198
}
199199
val firstKindedParamss = clsParamss.filter {

compiler/src/dotty/tools/dotc/typer/Implicits.scala

+76-29
Original file line numberDiff line numberDiff line change
@@ -699,16 +699,65 @@ trait Implicits { self: Typer =>
699699
if (ctx.inInlineMethod || enclosingInlineds.nonEmpty) ref(defn.TastyReflection_macroContext)
700700
else EmptyTree
701701

702-
/** If `formal` is of the form Eq[T, U], where no `Eq` instance exists for
703-
* either `T` or `U`, synthesize `Eq.eqAny[T, U]` as solution.
702+
/** If `formal` is of the form Eql[T, U], try to synthesize an
703+
* `Eql.eqlAny[T, U]` as solution.
704704
*/
705705
def synthesizedEq(formal: Type)(implicit ctx: Context): Tree = {
706-
//println(i"synth eq $formal / ${formal.argTypes}%, %")
706+
707+
/** Is there an `Eql[T, T]` instance, assuming -strictEquality? */
708+
def hasEq(tp: Type)(implicit ctx: Context): Boolean = {
709+
val inst = inferImplicitArg(defn.EqlType.appliedTo(tp, tp), span)
710+
!inst.isEmpty && !inst.tpe.isError
711+
}
712+
713+
/** Can we assume the eqlAny instance for `tp1`, `tp2`?
714+
* This is the case if assumedCanEqual(tp1, tp2), or
715+
* one of `tp1`, `tp2` has a reflexive `Eql` instance.
716+
*/
717+
def validEqAnyArgs(tp1: Type, tp2: Type)(implicit ctx: Context) =
718+
assumedCanEqual(tp1, tp2) || {
719+
val nestedCtx = ctx.fresh.addMode(Mode.StrictEquality)
720+
!hasEq(tp1)(nestedCtx) && !hasEq(tp2)(nestedCtx)
721+
}
722+
723+
/** Is an `Eql[cls1, cls2]` instance assumed for predefined classes `cls1`, cls2`? */
724+
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol): Boolean = {
725+
def cmpWithBoxed(cls1: ClassSymbol, cls2: ClassSymbol) =
726+
cls2 == defn.boxedType(cls1.typeRef).symbol ||
727+
cls1.isNumericValueClass && cls2.derivesFrom(defn.BoxedNumberClass)
728+
729+
if (cls1.isPrimitiveValueClass)
730+
if (cls2.isPrimitiveValueClass)
731+
cls1 == cls2 || cls1.isNumericValueClass && cls2.isNumericValueClass
732+
else
733+
cmpWithBoxed(cls1, cls2)
734+
else if (cls2.isPrimitiveValueClass)
735+
cmpWithBoxed(cls2, cls1)
736+
else if (cls1 == defn.NullClass)
737+
cls1 == cls2 || cls2.derivesFrom(defn.ObjectClass)
738+
else if (cls2 == defn.NullClass)
739+
cls1.derivesFrom(defn.ObjectClass)
740+
else
741+
false
742+
}
743+
744+
/** Some simulated `Eql` instances for predefined types. It's more efficient
745+
* to do this directly instead of setting up a lot of `Eql` instances to
746+
* interpret.
747+
*/
748+
def canComparePredefined(tp1: Type, tp2: Type) =
749+
tp1.classSymbols.exists(cls1 =>
750+
tp2.classSymbols.exists(cls2 => canComparePredefinedClasses(cls1, cls2)))
751+
707752
formal.argTypes match {
708-
case args @ (arg1 :: arg2 :: Nil)
709-
if !ctx.featureEnabled(defn.LanguageModuleClass, nme.strictEquality) &&
710-
ctx.test(implicit ctx => validEqAnyArgs(arg1, arg2)) =>
711-
ref(defn.Eq_eqAny).appliedToTypes(args).withSpan(span)
753+
case args @ (arg1 :: arg2 :: Nil) =>
754+
List(arg1, arg2).foreach(fullyDefinedType(_, "eq argument", span))
755+
if (canComparePredefined(arg1, arg2)
756+
||
757+
!strictEquality &&
758+
ctx.test(implicit ctx => validEqAnyArgs(arg1, arg2)))
759+
ref(defn.Eql_eqlAny).appliedToTypes(args).withSpan(span)
760+
else EmptyTree
712761
case _ =>
713762
EmptyTree
714763
}
@@ -737,14 +786,6 @@ trait Implicits { self: Typer =>
737786
}
738787
}
739788

740-
def hasEq(tp: Type): Boolean =
741-
inferImplicit(defn.EqType.appliedTo(tp, tp), EmptyTree, span).isSuccess
742-
743-
def validEqAnyArgs(tp1: Type, tp2: Type)(implicit ctx: Context) = {
744-
List(tp1, tp2).foreach(fullyDefinedType(_, "eqAny argument", span))
745-
assumedCanEqual(tp1, tp2) || !hasEq(tp1) && !hasEq(tp2)
746-
}
747-
748789
/** If `formal` is of the form `scala.reflect.Generic[T]` for some class type `T`,
749790
* synthesize an instance for it.
750791
*/
@@ -776,7 +817,7 @@ trait Implicits { self: Typer =>
776817
trySpecialCase(defn.QuotedTypeClass, synthesizedTypeTag,
777818
trySpecialCase(defn.GenericClass, synthesizedGeneric,
778819
trySpecialCase(defn.TastyReflectionClass, synthesizedTastyContext,
779-
trySpecialCase(defn.EqClass, synthesizedEq,
820+
trySpecialCase(defn.EqlClass, synthesizedEq,
780821
trySpecialCase(defn.ValueOfClass, synthesizedValueOf, failed))))))
781822
}
782823
}
@@ -885,16 +926,16 @@ trait Implicits { self: Typer =>
885926
em"parameter ${paramName} of $methodStr"
886927
}
887928

888-
private def assumedCanEqual(ltp: Type, rtp: Type)(implicit ctx: Context) = {
889-
def eqNullable: Boolean = {
890-
val other =
891-
if (ltp.isRef(defn.NullClass)) rtp
892-
else if (rtp.isRef(defn.NullClass)) ltp
893-
else NoType
894-
895-
(other ne NoType) && !other.derivesFrom(defn.AnyValClass)
896-
}
929+
private def strictEquality(implicit ctx: Context): Boolean =
930+
ctx.mode.is(Mode.StrictEquality) ||
931+
ctx.featureEnabled(defn.LanguageModuleClass, nme.strictEquality)
897932

933+
/** An Eql[T, U] instance is assumed
934+
* - if one of T, U is an error type, or
935+
* - if one of T, U is a subtype of the lifted version of the other,
936+
* unless strict equality is set.
937+
*/
938+
private def assumedCanEqual(ltp: Type, rtp: Type)(implicit ctx: Context) = {
898939
// Map all non-opaque abstract types to their upper bound.
899940
// This is done to check whether such types might plausibly be comparable to each other.
900941
val lift = new TypeMap {
@@ -910,14 +951,20 @@ trait Implicits { self: Typer =>
910951
if (variance > 0) mapOver(t) else t
911952
}
912953
}
913-
ltp.isError || rtp.isError || ltp <:< lift(rtp) || rtp <:< lift(ltp) || eqNullable
954+
955+
ltp.isError ||
956+
rtp.isError ||
957+
!strictEquality && {
958+
ltp <:< lift(rtp) ||
959+
rtp <:< lift(ltp)
960+
}
914961
}
915962

916963
/** Check that equality tests between types `ltp` and `rtp` make sense */
917964
def checkCanEqual(ltp: Type, rtp: Type, span: Span)(implicit ctx: Context): Unit =
918965
if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) {
919-
val res = implicitArgTree(defn.EqType.appliedTo(ltp, rtp), span)
920-
implicits.println(i"Eq witness found for $ltp / $rtp: $res: ${res.tpe}")
966+
val res = implicitArgTree(defn.EqlType.appliedTo(ltp, rtp), span)
967+
implicits.println(i"Eql witness found for $ltp / $rtp: $res: ${res.tpe}")
921968
}
922969

923970
/** Find an implicit parameter or conversion.
@@ -985,7 +1032,7 @@ trait Implicits { self: Typer =>
9851032
if (argument.isEmpty) f(resultType) else ViewProto(f(argument.tpe.widen), f(resultType))
9861033
// Not clear whether we need to drop the `.widen` here. All tests pass with it in place, though.
9871034

988-
private def isCoherent = pt.isRef(defn.EqClass)
1035+
private def isCoherent = pt.isRef(defn.EqlClass)
9891036

9901037
private val cmpContext = nestedContext()
9911038
private val cmpCandidates = (c1: Candidate, c2: Candidate) => compare(c1.ref, c2.ref, c1.level, c2.level)(cmpContext)

compiler/src/dotty/tools/dotc/typer/Namer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ class Namer { typer: Typer =>
10291029

10301030
if (impl.derived.nonEmpty) {
10311031
val (derivingClass, derivePos) = original.removeAttachment(desugar.DerivingCompanion) match {
1032-
case Some(pos) => (cls.companionClass.asClass, pos)
1032+
case Some(pos) => (cls.companionClass.orElse(cls).asClass, pos)
10331033
case None => (cls, impl.sourcePos.startPos)
10341034
}
10351035
val deriver = new Deriver(derivingClass, derivePos)(localCtx)

compiler/src/dotty/tools/dotc/util/SourceFile.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class SourceFile(val file: AbstractFile, computeContent: => Array[Char]) extends
190190
}
191191
}
192192
object SourceFile {
193-
implicit def eqSource: Eq[SourceFile, SourceFile] = Eq
193+
implicit def eqSource: Eql[SourceFile, SourceFile] = Eql.derived
194194

195195
implicit def fromContext(implicit ctx: Context): SourceFile = ctx.source
196196

compiler/test-resources/repl/i4184

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ scala> object foo { class Foo }
22
// defined object foo
33
scala> object bar { class Foo }
44
// defined object bar
5-
scala> implicit def eqFoo: Eq[foo.Foo, foo.Foo] = Eq
6-
def eqFoo: Eq[foo.Foo, foo.Foo]
5+
scala> implicit def eqFoo: Eql[foo.Foo, foo.Foo] = Eql.derived
6+
def eqFoo: Eql[foo.Foo, foo.Foo]
77
scala> object Bar { new foo.Foo == new bar.Foo }
88
1 | object Bar { new foo.Foo == new bar.Foo }
99
| ^^^^^^^^^^^^^^^^^^^^^^^^^^

0 commit comments

Comments
 (0)