Skip to content

Commit 3b7c9ff

Browse files
committed
Use infix extension to indicate right associativity with natural order
1 parent e598bef commit 3b7c9ff

File tree

23 files changed

+172
-133
lines changed

23 files changed

+172
-133
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+12-8
Original file line numberDiff line numberDiff line change
@@ -996,11 +996,18 @@ object desugar {
996996

997997
def badRightAssoc(problem: String) =
998998
report.error(em"right-associative extension method $problem", mdef.srcPos)
999-
() // extParamss ++ mdef.paramss
999+
extParamss ++ mdef.paramss
10001000

10011001
rightParam match
10021002
case ValDefs(vparam :: Nil) =>
1003-
if !vparam.mods.is(Given) then
1003+
if vparam.mods.is(Given) then
1004+
badRightAssoc("cannot start with using clause")
1005+
else if mdef.mods.is(Infix) then
1006+
// New encoding:
1007+
// we keep the extension method as is and rely on the swap of arguments at call site
1008+
extParamss ++ mdef.paramss
1009+
else
1010+
// Old encoding:
10041011
// we merge the extension parameters with the method parameters,
10051012
// swapping the operator arguments:
10061013
// e.g.
@@ -1010,16 +1017,13 @@ object desugar {
10101017
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
10111018
//
10121019
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
1013-
// val (leftTyParamsAndLeadingUsing, leftParamAndTrailingUsing) = extParamss.span(isUsingOrTypeParamClause)
1014-
() // leftTyParamsAndLeadingUsing ::: rightTyParams ::: rightParam :: leftParamAndTrailingUsing ::: paramss1
1015-
else
1016-
badRightAssoc("cannot start with using clause")
1020+
val (leftTyParamsAndLeadingUsing, leftParamAndTrailingUsing) = extParamss.span(isUsingOrTypeParamClause)
1021+
leftTyParamsAndLeadingUsing ::: rightTyParams ::: rightParam :: leftParamAndTrailingUsing ::: paramss1
10171022
case _ =>
10181023
badRightAssoc("must start with a single parameter")
10191024
case _ =>
10201025
// no value parameters, so not an infix operator.
1021-
() // extParamss ++ mdef.paramss
1022-
extParamss ++ mdef.paramss
1026+
extParamss ++ mdef.paramss
10231027
else
10241028
extParamss ++ mdef.paramss
10251029
).withMods(mdef.mods | ExtensionMethod)

compiler/src/dotty/tools/dotc/ast/Positioned.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import util.{SourceFile, SourcePosition, SrcPos}
77
import core.Contexts.*
88
import core.Decorators.*
99
import core.NameOps.*
10-
import core.Flags.{JavaDefined, ExtensionMethod}
10+
import core.Flags.{JavaDefined, ExtensionMethod, Infix}
1111
import core.StdNames.nme
1212
import ast.Trees.mods
1313
import annotation.constructorOnly
@@ -215,8 +215,8 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
215215
check(tree.trailingParamss)
216216
case tree: DefDef if tree.mods.is(ExtensionMethod) =>
217217
tree.paramss match
218-
// case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
219-
// // omit check for right-associatiove extension methods; their parameters were swapped
218+
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName && !tree.mods.is(Infix) =>
219+
// omit check for right-associatiove extension methods; their parameters were swapped
220220
case _ =>
221221
check(tree.paramss)
222222
check(tree.tpt)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+29-29
Original file line numberDiff line numberDiff line change
@@ -933,35 +933,35 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
933933
val coreSig =
934934
if isExtension then
935935
val paramss =
936-
// if tree.name.isRightAssocOperatorName then
937-
// // If you change the names of the clauses below, also change them in right-associative-extension-methods.md
938-
// // we have the following encoding of tree.paramss:
939-
// // (leftTyParams ++ leadingUsing
940-
// // ++ rightTyParams ++ rightParam
941-
// // ++ leftParam ++ trailingUsing ++ rest)
942-
// // e.g.
943-
// // extension [A](using B)(c: C)(using D)
944-
// // def %:[E](f: F)(g: G)(using H): Res = ???
945-
// // will have the following values:
946-
// // - leftTyParams = List(`[A]`)
947-
// // - leadingUsing = List(`(using B)`)
948-
// // - rightTyParams = List(`[E]`)
949-
// // - rightParam = List(`(f: F)`)
950-
// // - leftParam = List(`(c: C)`)
951-
// // - trailingUsing = List(`(using D)`)
952-
// // - rest = List(`(g: G)`, `(using H)`)
953-
// // we need to swap (rightTyParams ++ rightParam) with (leftParam ++ trailingUsing)
954-
// val (leftTyParams, rest1) = tree.paramss.span(isTypeParamClause)
955-
// val (leadingUsing, rest2) = rest1.span(isUsingClause)
956-
// val (rightTyParams, rest3) = rest2.span(isTypeParamClause)
957-
// val (rightParam, rest4) = rest3.splitAt(1)
958-
// val (leftParam, rest5) = rest4.splitAt(1)
959-
// val (trailingUsing, rest6) = rest5.span(isUsingClause)
960-
// if leftParam.nonEmpty then
961-
// leftTyParams ::: leadingUsing ::: leftParam ::: trailingUsing ::: rightTyParams ::: rightParam ::: rest6
962-
// else
963-
// tree.paramss // it wasn't a binary operator, after all.
964-
// else
936+
if tree.name.isRightAssocOperatorName && !tree.mods.is(Infix) && !tree.symbol.is(Infix) then
937+
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
938+
// we have the following encoding of tree.paramss:
939+
// (leftTyParams ++ leadingUsing
940+
// ++ rightTyParams ++ rightParam
941+
// ++ leftParam ++ trailingUsing ++ rest)
942+
// e.g.
943+
// extension [A](using B)(c: C)(using D)
944+
// def %:[E](f: F)(g: G)(using H): Res = ???
945+
// will have the following values:
946+
// - leftTyParams = List(`[A]`)
947+
// - leadingUsing = List(`(using B)`)
948+
// - rightTyParams = List(`[E]`)
949+
// - rightParam = List(`(f: F)`)
950+
// - leftParam = List(`(c: C)`)
951+
// - trailingUsing = List(`(using D)`)
952+
// - rest = List(`(g: G)`, `(using H)`)
953+
// we need to swap (rightTyParams ++ rightParam) with (leftParam ++ trailingUsing)
954+
val (leftTyParams, rest1) = tree.paramss.span(isTypeParamClause)
955+
val (leadingUsing, rest2) = rest1.span(isUsingClause)
956+
val (rightTyParams, rest3) = rest2.span(isTypeParamClause)
957+
val (rightParam, rest4) = rest3.splitAt(1)
958+
val (leftParam, rest5) = rest4.splitAt(1)
959+
val (trailingUsing, rest6) = rest5.span(isUsingClause)
960+
if leftParam.nonEmpty then
961+
leftTyParams ::: leadingUsing ::: leftParam ::: trailingUsing ::: rightTyParams ::: rightParam ::: rest6
962+
else
963+
tree.paramss // it wasn't a binary operator, after all.
964+
else
965965
tree.paramss
966966
val trailingParamss = paramss
967967
.dropWhile(isUsingOrTypeParamClause)

docs/_docs/reference/contextual/right-associative-extension-methods.md

+39-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,22 @@ single explicit term parameter (in other words, `rightParam` is present). In the
3636

3737
The Scala compiler pre-processes a right-associative infix operation such as `x +: xs`
3838
to `xs.+:(x)` if `x` is a pure expression or a call-by-name parameter and to `val y = x; xs.+:(y)` otherwise. This is necessary since a regular right-associative infix method
39-
is defined in the class of its right operand. To make up for this swap,
39+
is defined in the class of its right operand.
40+
41+
### Natural Order Right-Associative Extension Methods
42+
If the right-associative extension methods is defined as infix, then the extension is used in its natural order. The `leftParam` is the receiver and
43+
the `rightParam` is the argument. The order of the parameters is kept consistent with the order of the arguments at call site after desugaring.
44+
For instance:
45+
46+
```scala
47+
extension [T](xs: List[T])
48+
infix def +:: (x: T): List[T] = ...
49+
50+
y +:: ys // ys.::y // +::(ys)(y)
51+
```
52+
53+
### Inverted Right-Associative Extension Methods
54+
To make up for the swap in the order at call site,
4055
the expansion of right-associative extension methods performs the inverse parameter swap. More precisely, if `rightParam` is present, the total parameter sequence
4156
of the extension method's expansion is:
4257

@@ -59,6 +74,27 @@ For instance, the `+::` method above would become
5974
```
6075

6176
This expansion has to be kept in mind when writing right-associative extension
62-
methods with inter-parameter dependencies.
77+
methods with inter-parameter dependencies. To avoid this limitation use _natural order right-associative extension methods_.
78+
79+
This expansion also introduces some inconsistencies when calling the extension methods in non infix form. The user needs to invert the order of the arguments at call site manually. For instance:
80+
81+
```scala
82+
extension [T](x: T)
83+
def *:(xs: List[T]): List[T] = ...
84+
85+
y.*:(ys) // error when following the parameter definition order
86+
ys.*:(y)
87+
88+
*:(y)(ys) // error when following the parameter definition order
89+
*:(ys)(y)
90+
```
6391

64-
An overall simpler design could be obtained if right-associative operators could _only_ be defined as extension methods, and would be disallowed as normal methods. In that case neither arguments nor parameters would have to be swapped. Future versions of Scala should strive to achieve this simplification.
92+
Another limitation of this representation is that it is impossible to pass the
93+
type parameters of the `def` explicitly. For instance:
94+
95+
```scala
96+
extension (x: Int)
97+
def *:[T](xs: List[T]): List[T] = ???
98+
99+
xs.*:[Int](1) // error when trying to set T explicitly
100+
```

library/src/scala/IArray.scala

+8-4
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,14 @@ object IArray:
319319
def zipAll[T1 >: T, U](that: Iterable[U], thisElem: T1, thatElem: U): IArray[(T1, U)] = genericArrayOps(arr).zipAll(that, thisElem, thatElem)
320320
def zipWithIndex: IArray[(T, Int)] = genericArrayOps(arr).zipWithIndex
321321

322-
extension [T, U >: T: ClassTag](arr: IArray[U])
323-
def ++:(prefix: IterableOnce[T]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
324-
def ++:(prefix: IArray[T]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
325-
def +:(x: T): IArray[U] = genericArrayOps(arr).prepended(x)
322+
extension [T, U >: T: ClassTag](prefix: IterableOnce[T])
323+
def ++:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
324+
325+
extension [T, U >: T: ClassTag](prefix: IArray[T])
326+
def ++:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
327+
328+
extension [T, U >: T: ClassTag](x: T)
329+
def +:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prepended(x)
326330

327331
// For backwards compatibility with code compiled without -Yexplicit-nulls
328332
private inline def mapNull[A, B](a: A, inline f: B): B =

presentation-compiler/test/dotty/tools/pc/tests/hover/HoverTypeSuite.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ class HoverTypeSuite extends BaseHoverSuite:
152152
|class C
153153
|
154154
|object Foo:
155-
| extension [R](using A)(res: R)(using B)
156-
| def %:[T](main: T)(using C): R = ???
155+
| extension [T](using A)(main: T)(using B)
156+
| def %:[R](res: R)(using C): R = ???
157157
| given A with {}
158158
| given B with {}
159159
| given C with {}
@@ -162,7 +162,7 @@ class HoverTypeSuite extends BaseHoverSuite:
162162
|end Foo
163163
|""".stripMargin,
164164
"""|Int
165-
|extension [R](using A)(using B)(res: R) def %:[T](main: T)(using C): R""".stripMargin.hover
165+
|extension [T](using A)(main: T) def %:[R](res: R)(using B)(using C): R""".stripMargin.hover
166166
)
167167

168168
@Test def `using` =

tests/neg-custom-args/captures/lazylists-exceptions.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ final class LazyCons[+T](val x: T, val xs: () => LazyList[T]^) extends LazyList[
2020
def tail: LazyList[T]^{this} = xs()
2121
end LazyCons
2222

23-
extension [A](xs1: => LazyList[A]^)
24-
def #:(x: A): LazyList[A]^{xs1} =
23+
extension [A](x: A)
24+
def #:(xs1: => LazyList[A]^): LazyList[A]^{xs1} =
2525
LazyCons(x, () => xs1)
2626

2727
def tabulate[A](n: Int)(gen: Int => A): LazyList[A]^{gen} =

tests/neg/i13075.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ object Implementing_Tuples:
88
type *:[H, T <: Tup] = ConsTup[H, T] // for type matching
99
type EmptyTup = EmptyTup.type // for type matching
1010

11-
extension [T <: Tup](tail: T)
12-
def *:[H](head: H) = ConsTup(head, tail)
11+
extension [H](head: H)
12+
def *:[T <: Tup](tail: T) = ConsTup(head, tail)
1313

1414
type Fold[T <: Tup, Seed, F[_,_]] = T match
1515
case EmptyTup => Seed

tests/neg/i9562.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,4 @@ object Unrelated:
99
def h1: Int = foo // error
1010
def h2: Int = h1 + 1 // OK
1111
def h3: Int = g // error
12-
extension (x: Int)
13-
def ++:(f: Foo): Int = f.h1 + x // OK
12+
def ++: (x: Int): Int = h1 + x // OK

tests/pos-custom-args/captures/lazylists-exceptions.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ extension [A](xs: LzyList[A]^)
4242
if n == 0 then xs else xs.tail.drop(n - 1)
4343
end extension
4444

45-
extension [A](xs1: => LzyList[A]^)
46-
def #:(x: A): LzyList[A]^{xs1} =
45+
extension [A](x: A)
46+
def #:(xs1: => LzyList[A]^): LzyList[A]^{xs1} =
4747
LzyCons(x, () => xs1)
4848

4949
def lazyCons[A](x: A, xs1: => LzyList[A]^): LzyList[A]^{xs1} =

tests/pos-custom-args/captures/logger.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ final class LazyCons[+T](val x: T, val xs: () => LazyList[T]^) extends LazyList[
3232
def tail: LazyList[T]^{this} = xs()
3333
end LazyCons
3434

35-
extension [A](xs1: => LazyList[A]^)
36-
def #::(x: A): LazyList[A]^{xs1} =
35+
extension [A](x: A)
36+
def #::(xs1: => LazyList[A]^): LazyList[A]^{xs1} =
3737
LazyCons(x, () => xs1)
3838

3939
extension [A](xs: LazyList[A]^)

tests/pos-custom-args/captures/strictlists.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ extension [A](xs: StrictList[A])
2828
def concat(ys: StrictList[A]): StrictList[A] =
2929
if xs.isEmpty then ys
3030
else xs.head #: xs.tail.concat(ys)
31-
32-
def #:(x: A): StrictList[A] =
33-
StrictCons(x, xs)
3431
end extension
3532

33+
extension [A](x: A)
34+
def #:(xs1: StrictList[A]): StrictList[A] =
35+
StrictCons(x, xs1)
36+
3637
def tabulate[A](n: Int)(gen: Int => A) =
3738
def recur(i: Int): StrictList[A] =
3839
if i == n then StrictNil

tests/pos/i19197.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
extension (tuple: Tuple)
2-
def **:[T >: tuple.type <: Tuple, H](x: H): H *: T = ???
2+
infix def **:[T >: tuple.type <: Tuple, H](x: H): H *: T = ???
33

44
def test1: (Int, String, Char) = 1 **: ("a", 'b')
55
def test2: (Int, String, Char) = ("a", 'b').**:(1)

tests/pos/i9562.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,4 @@ object Unrelated:
88
extension (f: Foo)
99
def h1: Int = 0
1010
def h2: Int = h1 + 1 // OK
11-
12-
extension (x: Int)
13-
def ++: (f: Foo): Int = f.h2 + x // OK
11+
def ++: (x: Int): Int = h2 + x // OK

tests/pos/reference/extension-methods.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ object ExtMethods:
1010
assert(circle.circumference == circumference(circle))
1111

1212
extension (x: String) def < (y: String) = x.compareTo(y) < 0
13-
extension [Elem](xs: Seq[Elem]) def #: (x: Elem) = x +: xs
13+
extension [Elem](x: Elem) def #: (xs: Seq[Elem]) = x +: xs
1414
extension (x: Number) infix def min (y: Number) = x
1515

1616
assert("a" < "bb")

tests/run/errorhandling/Result.scala

+1-4
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ object Result:
3535
case (Err(e), Ok(_)) => Err(e :: Nil)
3636
case (Err(e1), Err(e2)) => Err(e1 :: e2 :: Nil)
3737

38-
end extension
39-
40-
extension [U <: Tuple, E](other: Result[U, List[E]])
4138
/** Validate both `r` and `other`; return a tuple of successes or a list of failures.
4239
* Unlike with `zip`, the right hand side `other` must be a `Result` returning a `Tuple`,
4340
* and the left hand side is added to it. See `Result.empty` for a convenient
4441
* right unit of chains of `*:`s.
4542
*/
46-
def *: [T](r: Result[T, E]): Result[T *: U, List[E]] = (r, other) match
43+
def *: [U <: Tuple](other: Result[U, List[E]]): Result[T *: U, List[E]] = (r, other) match
4744
case (Ok(x), Ok(ys)) => Ok(x *: ys)
4845
case (Ok(_), es: Err[?]) => es
4946
case (Err(e), Ok(_)) => Err(e :: Nil)

tests/run/export-in-extension.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ object O:
1111
export cm.*
1212
def succ: Int = x + 1
1313
def succ2: Int = succ + 1
14-
def ::: (y: Int) = y - x
14+
def ::: (y: Int) = x - y
1515

1616
object O2:
1717
import O.C
@@ -20,7 +20,7 @@ object O2:
2020
export cm.{bar, baz, bam, ::}
2121
def succ: Int = x + 1
2222
def succ2: Int = succ + 1
23-
def ::: (y: Int) = y - x
23+
def ::: (y: Int) = x - y
2424

2525
@main def Test =
2626
import O.*

tests/run/i11583.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ class Env:
1212
// */
1313
// def &&:[T <: ctx.Term](trm: T)(ext: env.Extra): (ctx.Type, T, env.Extra) = (tpe, trm, ext)
1414

15-
extension [Ctx <: Context, T <: Boolean](using ctx: Ctx)(trm: T)(using env: Env)
16-
def :#:(tpe: String)(ext: env.Extra): (String, T, env.Extra) = (tpe, trm, ext)
15+
extension [Ctx <: Context](using ctx: Ctx)(tpe: String)(using env: Env)
16+
def :#:[T <: Boolean](trm: T)(ext: env.Extra): (String, T, env.Extra) = (tpe, trm, ext)
1717

18-
extension [T <: Tuple](t: T)
19-
def :*:[A](a: A): A *: T = a *: t
18+
extension [A](a: A)
19+
def :*:[T <: Tuple](t: T): A *: T = a *: t
2020

2121
@main def Test =
2222

tests/run/i9530.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ trait Scope:
88
extension (using s: Scope)(expr: s.Expr)
99
def show = expr.toString
1010
def eval = s.value(expr)
11-
def *: (other: s.Expr) = s.combine(other, expr)
11+
def *: (other: s.Expr) = s.combine(expr, other)
1212

1313
def f(using s: Scope)(x: s.Expr): (String, s.Value) =
1414
(x.show, x.eval)

tests/run/instances.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object Test extends App {
3131
extension [T](xs: List[List[T]])
3232
def flattened = xs.foldLeft[List[T]](Nil)(_ ++ _)
3333

34-
extension [T](xs: Seq[T]) def :: (x: T) = x +: xs
34+
extension [T](x: T) def :: (xs: Seq[T]) = x +: xs
3535

3636
val ss: Seq[Int] = List(1, 2, 3)
3737
val ss1 = 0 :: ss

0 commit comments

Comments
 (0)