Skip to content

Commit 29e229b

Browse files
Inline traits numeric example (#17098)
Rewrite of the Numeric library. This new version uses inlining to improve performance on basic numeric operations.
2 parents f1bc0fb + e9d8718 commit 29e229b

File tree

5 files changed

+347
-0
lines changed

5 files changed

+347
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package scala.math
2+
package inline
3+
4+
trait Fractional[T] extends Numeric[T]:
5+
transparent inline def div(inline x: T, inline y: T): T
6+
protected transparent inline def isNaN(inline x: T): Boolean
7+
protected transparent inline def isNegZero(inline x: T): Boolean
8+
9+
extension (inline x: T)
10+
transparent inline def abs: T =
11+
if lt(x, zero) || isNegZero(x) then negate(x) else x
12+
transparent inline def sign: T =
13+
if isNaN(x) || isNegZero(x) then x
14+
else if lt(x, zero) then negate(one)
15+
else if gt(x, zero) then one
16+
else zero
17+
transparent inline def /(inline y: T) = div(x, y)
18+
19+
object Fractional:
20+
given BigDecimalIsFractional: BigDecimalIsConflicted with Fractional[BigDecimal] with
21+
transparent inline def div(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x / y
22+
23+
protected transparent inline def isNaN(inline x: BigDecimal): Boolean = false
24+
protected transparent inline def isNegZero(inline x: BigDecimal): Boolean = false
25+
26+
given DoubleIsFractional: Fractional[Double] with Ordering.DoubleIeeeOrdering with
27+
transparent inline def plus(inline x: Double, inline y: Double): Double = x + y
28+
transparent inline def minus(inline x: Double, inline y: Double): Double = x - y
29+
transparent inline def times(inline x: Double, inline y: Double): Double = x * y
30+
transparent inline def div(inline x: Double, inline y: Double): Double = x / y
31+
transparent inline def negate(inline x: Double): Double = -x
32+
33+
transparent inline def fromInt(x: Int): Double = x.toDouble
34+
def parseString(str: String): Option[Double] = str.toDoubleOption
35+
36+
protected transparent inline def isNaN(inline x: Double): Boolean = x.isNaN
37+
protected transparent inline def isNegZero(inline x: Double): Boolean = x.equals(-0.0)
38+
39+
extension (inline x: Double)
40+
transparent inline def toInt: Int = x.toInt
41+
transparent inline def toLong: Long = x.toLong
42+
transparent inline def toFloat: Float = x.toFloat
43+
transparent inline def toDouble: Double = x
44+
45+
given FloatIsFractional: Fractional[Float] with Ordering.FloatIeeeOrdering with
46+
transparent inline def plus(inline x: Float, inline y: Float): Float = x + y
47+
transparent inline def minus(inline x: Float, inline y: Float): Float = x - y
48+
transparent inline def times(inline x: Float, inline y: Float): Float = x * y
49+
transparent inline def div(inline x: Float, inline y: Float): Float = x / y
50+
transparent inline def negate(inline x: Float): Float = -x
51+
52+
transparent inline def fromInt(x: Int): Float = x.toFloat
53+
def parseString(str: String): Option[Float] = str.toFloatOption
54+
55+
protected transparent inline def isNaN(inline x: Float): Boolean = x.isNaN
56+
protected transparent inline def isNegZero(inline x: Float): Boolean = x.equals(-0f)
57+
58+
extension (inline x: Float)
59+
transparent inline def toInt: Int = x.toInt
60+
transparent inline def toLong: Long = x.toLong
61+
transparent inline def toFloat: Float = x
62+
transparent inline def toDouble: Double = x.toDouble
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package scala.math
2+
package inline
3+
4+
import scala.util.Try
5+
6+
trait Integral[T] extends Numeric[T]:
7+
inline def quot(inline x: T, inline y: T): T
8+
inline def rem(inline x: T, inline y: T): T
9+
10+
extension (inline x: T)
11+
transparent inline def abs: T =
12+
if lt(x, zero) then negate(x) else x
13+
transparent inline def sign: T =
14+
if lt(x, zero) then negate(one)
15+
else if gt(x, zero) then one
16+
else zero
17+
transparent inline def /(inline y: T) = quot(x, y)
18+
transparent inline def %(inline y: T) = rem(x, y)
19+
transparent inline def /%(inline y: T) = (quot(x, y), rem(x, y))
20+
21+
object Integral:
22+
given BigDecimalAsIfIntegral: Integral[BigDecimal] with BigDecimalIsConflicted with
23+
transparent inline def quot(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x quot y
24+
transparent inline def rem(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x remainder y
25+
26+
given BigIntIsIntegral: Integral[BigInt] with Ordering.BigIntOrdering with
27+
transparent inline def plus(inline x: BigInt, inline y: BigInt): BigInt = x + y
28+
transparent inline def minus(inline x: BigInt, inline y: BigInt): BigInt = x - y
29+
transparent inline def times(inline x: BigInt, inline y: BigInt): BigInt = x * y
30+
transparent inline def negate(inline x: BigInt): BigInt = -x
31+
32+
extension (inline x: BigInt)
33+
transparent inline def toInt: Int = x.intValue
34+
transparent inline def toLong: Long = x.longValue
35+
transparent inline def toFloat: Float = x.floatValue
36+
transparent inline def toDouble: Double = x.doubleValue
37+
38+
transparent inline def fromInt(x: Int): BigInt = BigInt(x)
39+
def parseString(str: String): Option[BigInt] = Try(BigInt(str)).toOption
40+
41+
transparent inline def quot(inline x: BigInt, inline y: BigInt): BigInt = x / y
42+
transparent inline def rem(inline x: BigInt, inline y: BigInt): BigInt = x % y
43+
44+
given ByteIsIntegral: Integral[Byte] with Ordering.ByteOrdering with
45+
transparent inline def plus(inline x: Byte, inline y: Byte): Byte = (x + y).toByte
46+
transparent inline def minus(inline x: Byte, inline y: Byte): Byte = (x - y).toByte
47+
transparent inline def times(inline x: Byte, inline y: Byte): Byte = (x * y).toByte
48+
transparent inline def negate(inline x: Byte): Byte = (-x).toByte
49+
50+
transparent inline def fromInt(x: Int): Byte = x.toByte
51+
def parseString(str: String): Option[Byte] = str.toByteOption
52+
53+
transparent inline def quot(inline x: Byte, inline y: Byte): Byte = (x / y).toByte
54+
transparent inline def rem(inline x: Byte, inline y: Byte): Byte = (x % y).toByte
55+
56+
extension (inline x: Byte)
57+
transparent inline def toInt: Int = x.toInt
58+
transparent inline def toLong: Long = x.toLong
59+
transparent inline def toFloat: Float = x.toFloat
60+
transparent inline def toDouble: Double = x.toDouble
61+
62+
given CharIsIntegral: Integral[Char] with Ordering.CharOrdering with
63+
transparent inline def plus(inline x: Char, inline y: Char): Char = (x + y).toChar
64+
transparent inline def minus(inline x: Char, inline y: Char): Char = (x - y).toChar
65+
transparent inline def times(inline x: Char, inline y: Char): Char = (x * y).toChar
66+
transparent inline def negate(inline x: Char): Char = (-x).toChar
67+
68+
transparent inline def fromInt(x: Int): Char = x.toChar
69+
def parseString(str: String): Option[Char] = Try(str.toInt.toChar).toOption
70+
71+
transparent inline def quot(inline x: Char, inline y: Char): Char = (x / y).toChar
72+
transparent inline def rem(inline x: Char, inline y: Char): Char = (x % y).toChar
73+
74+
extension (inline x: Char)
75+
transparent inline def toInt: Int = x.toInt
76+
transparent inline def toLong: Long = x.toLong
77+
transparent inline def toFloat: Float = x.toFloat
78+
transparent inline def toDouble: Double = x.toDouble
79+
80+
given IntIsIntegral: Integral[Int] with Ordering.IntOrdering with
81+
transparent inline def plus(inline x: Int, inline y: Int): Int = x + y
82+
transparent inline def minus(inline x: Int, inline y: Int): Int = x - y
83+
transparent inline def times(inline x: Int, inline y: Int): Int = x * y
84+
transparent inline def negate(inline x: Int): Int = -x
85+
86+
transparent inline def fromInt(x: Int): Int = x
87+
def parseString(str: String): Option[Int] = str.toIntOption
88+
89+
transparent inline def quot(inline x: Int, inline y: Int): Int = x / y
90+
transparent inline def rem(inline x: Int, inline y: Int): Int = x % y
91+
92+
extension (inline x: Int)
93+
transparent inline def toInt: Int = x
94+
transparent inline def toLong: Long = x.toLong
95+
transparent inline def toFloat: Float = x.toFloat
96+
transparent inline def toDouble: Double = x.toDouble
97+
98+
given LongIsIntegral: Integral[Long] with Ordering.LongOrdering with
99+
transparent inline def plus(inline x: Long, inline y: Long): Long = x + y
100+
transparent inline def minus(inline x: Long, inline y: Long): Long = x - y
101+
transparent inline def times(inline x: Long, inline y: Long): Long = x * y
102+
transparent inline def negate(inline x: Long): Long = -x
103+
104+
transparent inline def fromInt(x: Int): Long = x.toLong
105+
def parseString(str: String): Option[Long] = str.toLongOption
106+
107+
transparent inline def quot(inline x: Long, inline y: Long): Long = (x / y).toLong
108+
transparent inline def rem(inline x: Long, inline y: Long): Long = (x % y).toLong
109+
110+
extension (inline x: Long)
111+
transparent inline def toInt: Int = x.toInt
112+
transparent inline def toLong: Long = x
113+
transparent inline def toFloat: Float = x.toFloat
114+
transparent inline def toDouble: Double = x.toDouble
115+
116+
given ShortIsIntegral: Integral[Short] with Ordering.ShortOrdering with
117+
transparent inline def plus(inline x: Short, inline y: Short): Short = (x + y).toShort
118+
transparent inline def minus(inline x: Short, inline y: Short): Short = (x - y).toShort
119+
transparent inline def times(inline x: Short, inline y: Short): Short = (x * y).toShort
120+
transparent inline def negate(inline x: Short): Short = (-x).toShort
121+
122+
transparent inline def fromInt(x: Int): Short = x.toShort
123+
def parseString(str: String): Option[Short] = str.toShortOption
124+
125+
transparent inline def quot(inline x: Short, inline y: Short): Short = (x / y).toShort
126+
transparent inline def rem(inline x: Short, inline y: Short): Short = (x % y).toShort
127+
128+
extension (inline x: Short)
129+
transparent inline def toInt: Int = x.toInt
130+
transparent inline def toLong: Long = x.toLong
131+
transparent inline def toFloat: Float = x.toFloat
132+
transparent inline def toDouble: Double = x.toDouble
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package scala.math
2+
package inline
3+
4+
import scala.util.Try
5+
6+
trait Numeric[T] extends Ordering[T]:
7+
inline def plus(inline x: T, inline y: T): T
8+
inline def minus(inline x: T, inline y: T): T
9+
inline def times(inline x: T, inline y: T): T
10+
inline def negate(inline x: T): T
11+
12+
def fromInt(x: Int): T
13+
def parseString(str: String): Option[T]
14+
15+
transparent inline def zero = fromInt(0)
16+
transparent inline def one = fromInt(1)
17+
18+
extension (inline x: T)
19+
transparent inline def +(inline y: T): T = plus(x, y)
20+
transparent inline def -(inline y: T) = minus(x, y)
21+
transparent inline def *(inline y: T): T = times(x, y)
22+
transparent inline def unary_- = negate(x)
23+
inline def toInt: Int
24+
inline def toLong: Long
25+
inline def toFloat: Float
26+
inline def toDouble: Double
27+
inline def abs: T
28+
inline def sign: T
29+
30+
trait BigDecimalIsConflicted extends Numeric[BigDecimal] with Ordering.BigDecimalOrdering:
31+
transparent inline def plus(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x + y
32+
transparent inline def minus(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x - y
33+
transparent inline def times(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x * y
34+
transparent inline def negate(inline x: BigDecimal): BigDecimal = -x
35+
36+
transparent inline def fromInt(x: Int): BigDecimal = BigDecimal(x)
37+
def parseString(str: String): Option[BigDecimal] = Try(BigDecimal(str)).toOption
38+
39+
extension (inline x: BigDecimal)
40+
transparent inline def toInt: Int = x.intValue
41+
transparent inline def toLong: Long = x.longValue
42+
transparent inline def toFloat: Float = x.floatValue
43+
transparent inline def toDouble: Double = x.doubleValue
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package scala.math
2+
package inline
3+
4+
import java.util.Comparator
5+
6+
trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable:
7+
outer =>
8+
9+
inline def tryCompare(x: T, y: T) = Some(compare(x, y))
10+
11+
def compare(x: T, y: T): Int
12+
13+
override inline def lteq(x: T, y: T): Boolean = compare(x, y) <= 0
14+
override inline def gteq(x: T, y: T): Boolean = compare(x, y) >= 0
15+
override inline def lt(x: T, y: T): Boolean = compare(x, y) < 0
16+
override inline def gt(x: T, y: T): Boolean = compare(x, y) > 0
17+
override inline def equiv(x: T, y: T): Boolean = compare(x, y) == 0
18+
19+
inline def max(x: T, y: T): T = if gteq(x, y) then x else y
20+
inline def min(x: T, y: T): T = if lteq(x, y) then x else y
21+
22+
// This is made into a separate trait, because defining the reverse ordering
23+
// anonymously results in an error:
24+
// Implementation restriction: nested inline methods are not supported
25+
inline def on[U](f: U => T): Ordering[U] = new ReverseOrdering(f) {}
26+
27+
private trait ReverseOrdering[U](f: U => T) extends Ordering[U]:
28+
inline def compare(x: U, y: U) = outer.compare(f(x), f(y))
29+
30+
object Ordering:
31+
trait BigDecimalOrdering extends Ordering[BigDecimal]:
32+
inline def compare(x: BigDecimal, y: BigDecimal) = x.compare(y)
33+
34+
trait BigIntOrdering extends Ordering[BigInt]:
35+
inline def compare(x: BigInt, y: BigInt) = x.compare(y)
36+
37+
trait ByteOrdering extends Ordering[Byte]:
38+
inline def compare(x: Byte, y: Byte) = java.lang.Byte.compare(x, y)
39+
40+
trait CharOrdering extends Ordering[Char]:
41+
inline def compare(x: Char, y: Char) = java.lang.Character.compare(x, y)
42+
43+
trait IntOrdering extends Ordering[Int]:
44+
inline def compare(x: Int, y: Int) = java.lang.Integer.compare(x, y)
45+
46+
trait LongOrdering extends Ordering[Long]:
47+
inline def compare(x: Long, y: Long) = java.lang.Long.compare(x, y)
48+
49+
trait ShortOrdering extends Ordering[Short]:
50+
inline def compare(x: Short, y: Short) = java.lang.Short.compare(x, y)
51+
52+
trait FloatIeeeOrdering extends Ordering[Float]:
53+
inline def compare(x: Float, y: Float) = java.lang.Float.compare(x, y)
54+
55+
trait DoubleIeeeOrdering extends Ordering[Double]:
56+
inline def compare(x: Double, y: Double) = java.lang.Double.compare(x, y)

tests/run/inline-numeric/test.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import scala.math.inline.*
2+
import scala.math.inline.Ordering.*
3+
import scala.math.inline.Integral.given
4+
import scala.math.inline.Fractional.given
5+
6+
object tests:
7+
inline def foo[T: Numeric](inline a: T, inline b: T) =
8+
a + b * b
9+
10+
inline def div[T: Integral](inline a: T, inline b: T) =
11+
a / b % b
12+
13+
inline def div[T: Fractional](inline a: T, inline b: T) =
14+
a / b + a
15+
16+
inline def toInt[T: Numeric](inline a: T) =
17+
a.toInt
18+
19+
inline def explicitToInt[T](inline a: T)(using n: Numeric[T]) =
20+
n.toInt(a)
21+
22+
inline def sign[T: Numeric](inline a: T) =
23+
a.sign
24+
25+
inline def explicitPlus[T](inline a: T, inline b: T)(using n: Numeric[T]) =
26+
n.plus(a, b)
27+
28+
@main def Test =
29+
def a: Int = 0
30+
def b: Int = 1
31+
32+
val v1 = foo(a, b) // should be a + b * b // can check with -Xprint:inlining
33+
val v2 = foo(a.toShort, b.toShort) // should be a + b * b
34+
35+
val v3 = div(BigDecimal(a), BigDecimal(b))(using BigDecimalAsIfIntegral) // should be BigDecimal(a) quot BigDecimal(b) remainder BigDecimal(b)
36+
val v4 = div(BigDecimal(a), BigDecimal(b))(using BigDecimalIsFractional) // should be BigDecimal(a) / BigDecimal(b) + BigDecimal(a)
37+
38+
val v5 = toInt(a.toFloat) // should be a.toFloat.toInt
39+
val v6 = toInt(a) // should be a
40+
41+
val v7 = sign(a)
42+
val v8 = sign(a.toChar)
43+
val v9 = sign(-7F)
44+
45+
val v10 = sign(BigDecimal(a))(using BigDecimalAsIfIntegral)
46+
val v11 = sign(BigDecimal(a))(using BigDecimalIsFractional) // the condition with isNan() should be removed, i.e. it should be equivalent to v10
47+
48+
val v12 = explicitPlus(3, 5) // should be 8
49+
val v13 = explicitPlus(a, b) // should be a + b
50+
51+
val v14 = explicitToInt(3.2) // should be (3.2).toInt
52+
val v15 = explicitToInt(3) // should be 3
53+
val v16 = explicitToInt(a) // should be a
54+
val v17 = explicitToInt(a.toShort) // should be a.toShort.toInt

0 commit comments

Comments
 (0)