Skip to content

Inline traits numeric example #17098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 14, 2023
Merged
62 changes: 62 additions & 0 deletions tests/run/inline-numeric/Fractional.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package scala.math
package inline

trait Fractional[T] extends Numeric[T]:
transparent inline def div(inline x: T, inline y: T): T
protected transparent inline def isNaN(inline x: T): Boolean
protected transparent inline def isNegZero(inline x: T): Boolean

extension (inline x: T)
transparent inline def abs: T =
if lt(x, zero) || isNegZero(x) then negate(x) else x
transparent inline def sign: T =
if isNaN(x) || isNegZero(x) then x
else if lt(x, zero) then negate(one)
else if gt(x, zero) then one
else zero
transparent inline def /(inline y: T) = div(x, y)

object Fractional:
given BigDecimalIsFractional: BigDecimalIsConflicted with Fractional[BigDecimal] with
transparent inline def div(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x / y

protected transparent inline def isNaN(inline x: BigDecimal): Boolean = false
protected transparent inline def isNegZero(inline x: BigDecimal): Boolean = false

given DoubleIsFractional: Fractional[Double] with Ordering.DoubleIeeeOrdering with
transparent inline def plus(inline x: Double, inline y: Double): Double = x + y
transparent inline def minus(inline x: Double, inline y: Double): Double = x - y
transparent inline def times(inline x: Double, inline y: Double): Double = x * y
transparent inline def div(inline x: Double, inline y: Double): Double = x / y
transparent inline def negate(inline x: Double): Double = -x

transparent inline def fromInt(x: Int): Double = x.toDouble
def parseString(str: String): Option[Double] = str.toDoubleOption

protected transparent inline def isNaN(inline x: Double): Boolean = x.isNaN
protected transparent inline def isNegZero(inline x: Double): Boolean = x.equals(-0.0)

extension (inline x: Double)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x

given FloatIsFractional: Fractional[Float] with Ordering.FloatIeeeOrdering with
transparent inline def plus(inline x: Float, inline y: Float): Float = x + y
transparent inline def minus(inline x: Float, inline y: Float): Float = x - y
transparent inline def times(inline x: Float, inline y: Float): Float = x * y
transparent inline def div(inline x: Float, inline y: Float): Float = x / y
transparent inline def negate(inline x: Float): Float = -x

transparent inline def fromInt(x: Int): Float = x.toFloat
def parseString(str: String): Option[Float] = str.toFloatOption

protected transparent inline def isNaN(inline x: Float): Boolean = x.isNaN
protected transparent inline def isNegZero(inline x: Float): Boolean = x.equals(-0f)

extension (inline x: Float)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x
transparent inline def toDouble: Double = x.toDouble
132 changes: 132 additions & 0 deletions tests/run/inline-numeric/Integral.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package scala.math
package inline

import scala.util.Try

trait Integral[T] extends Numeric[T]:
inline def quot(inline x: T, inline y: T): T
inline def rem(inline x: T, inline y: T): T

extension (inline x: T)
transparent inline def abs: T =
if lt(x, zero) then negate(x) else x
transparent inline def sign: T =
if lt(x, zero) then negate(one)
else if gt(x, zero) then one
else zero
transparent inline def /(inline y: T) = quot(x, y)
transparent inline def %(inline y: T) = rem(x, y)
transparent inline def /%(inline y: T) = (quot(x, y), rem(x, y))

object Integral:
given BigDecimalAsIfIntegral: Integral[BigDecimal] with BigDecimalIsConflicted with
transparent inline def quot(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x quot y
transparent inline def rem(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x remainder y

given BigIntIsIntegral: Integral[BigInt] with Ordering.BigIntOrdering with
transparent inline def plus(inline x: BigInt, inline y: BigInt): BigInt = x + y
transparent inline def minus(inline x: BigInt, inline y: BigInt): BigInt = x - y
transparent inline def times(inline x: BigInt, inline y: BigInt): BigInt = x * y
transparent inline def negate(inline x: BigInt): BigInt = -x

extension (inline x: BigInt)
transparent inline def toInt: Int = x.intValue
transparent inline def toLong: Long = x.longValue
transparent inline def toFloat: Float = x.floatValue
transparent inline def toDouble: Double = x.doubleValue

transparent inline def fromInt(x: Int): BigInt = BigInt(x)
def parseString(str: String): Option[BigInt] = Try(BigInt(str)).toOption

transparent inline def quot(inline x: BigInt, inline y: BigInt): BigInt = x / y
transparent inline def rem(inline x: BigInt, inline y: BigInt): BigInt = x % y

given ByteIsIntegral: Integral[Byte] with Ordering.ByteOrdering with
transparent inline def plus(inline x: Byte, inline y: Byte): Byte = (x + y).toByte
transparent inline def minus(inline x: Byte, inline y: Byte): Byte = (x - y).toByte
transparent inline def times(inline x: Byte, inline y: Byte): Byte = (x * y).toByte
transparent inline def negate(inline x: Byte): Byte = (-x).toByte

transparent inline def fromInt(x: Int): Byte = x.toByte
def parseString(str: String): Option[Byte] = str.toByteOption

transparent inline def quot(inline x: Byte, inline y: Byte): Byte = (x / y).toByte
transparent inline def rem(inline x: Byte, inline y: Byte): Byte = (x % y).toByte

extension (inline x: Byte)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x.toDouble

given CharIsIntegral: Integral[Char] with Ordering.CharOrdering with
transparent inline def plus(inline x: Char, inline y: Char): Char = (x + y).toChar
transparent inline def minus(inline x: Char, inline y: Char): Char = (x - y).toChar
transparent inline def times(inline x: Char, inline y: Char): Char = (x * y).toChar
transparent inline def negate(inline x: Char): Char = (-x).toChar

transparent inline def fromInt(x: Int): Char = x.toChar
def parseString(str: String): Option[Char] = Try(str.toInt.toChar).toOption

transparent inline def quot(inline x: Char, inline y: Char): Char = (x / y).toChar
transparent inline def rem(inline x: Char, inline y: Char): Char = (x % y).toChar

extension (inline x: Char)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x.toDouble

given IntIsIntegral: Integral[Int] with Ordering.IntOrdering with
transparent inline def plus(inline x: Int, inline y: Int): Int = x + y
transparent inline def minus(inline x: Int, inline y: Int): Int = x - y
transparent inline def times(inline x: Int, inline y: Int): Int = x * y
transparent inline def negate(inline x: Int): Int = -x

transparent inline def fromInt(x: Int): Int = x
def parseString(str: String): Option[Int] = str.toIntOption

transparent inline def quot(inline x: Int, inline y: Int): Int = x / y
transparent inline def rem(inline x: Int, inline y: Int): Int = x % y

extension (inline x: Int)
transparent inline def toInt: Int = x
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x.toDouble

given LongIsIntegral: Integral[Long] with Ordering.LongOrdering with
transparent inline def plus(inline x: Long, inline y: Long): Long = x + y
transparent inline def minus(inline x: Long, inline y: Long): Long = x - y
transparent inline def times(inline x: Long, inline y: Long): Long = x * y
transparent inline def negate(inline x: Long): Long = -x

transparent inline def fromInt(x: Int): Long = x.toLong
def parseString(str: String): Option[Long] = str.toLongOption

transparent inline def quot(inline x: Long, inline y: Long): Long = (x / y).toLong
transparent inline def rem(inline x: Long, inline y: Long): Long = (x % y).toLong

extension (inline x: Long)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x.toDouble

given ShortIsIntegral: Integral[Short] with Ordering.ShortOrdering with
transparent inline def plus(inline x: Short, inline y: Short): Short = (x + y).toShort
transparent inline def minus(inline x: Short, inline y: Short): Short = (x - y).toShort
transparent inline def times(inline x: Short, inline y: Short): Short = (x * y).toShort
transparent inline def negate(inline x: Short): Short = (-x).toShort

transparent inline def fromInt(x: Int): Short = x.toShort
def parseString(str: String): Option[Short] = str.toShortOption

transparent inline def quot(inline x: Short, inline y: Short): Short = (x / y).toShort
transparent inline def rem(inline x: Short, inline y: Short): Short = (x % y).toShort

extension (inline x: Short)
transparent inline def toInt: Int = x.toInt
transparent inline def toLong: Long = x.toLong
transparent inline def toFloat: Float = x.toFloat
transparent inline def toDouble: Double = x.toDouble
43 changes: 43 additions & 0 deletions tests/run/inline-numeric/Numeric.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package scala.math
package inline

import scala.util.Try

trait Numeric[T] extends Ordering[T]:
inline def plus(inline x: T, inline y: T): T
inline def minus(inline x: T, inline y: T): T
inline def times(inline x: T, inline y: T): T
inline def negate(inline x: T): T

def fromInt(x: Int): T
def parseString(str: String): Option[T]

transparent inline def zero = fromInt(0)
transparent inline def one = fromInt(1)

extension (inline x: T)
transparent inline def +(inline y: T): T = plus(x, y)
transparent inline def -(inline y: T) = minus(x, y)
transparent inline def *(inline y: T): T = times(x, y)
transparent inline def unary_- = negate(x)
inline def toInt: Int
inline def toLong: Long
inline def toFloat: Float
inline def toDouble: Double
inline def abs: T
inline def sign: T

trait BigDecimalIsConflicted extends Numeric[BigDecimal] with Ordering.BigDecimalOrdering:
transparent inline def plus(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x + y
transparent inline def minus(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x - y
transparent inline def times(inline x: BigDecimal, inline y: BigDecimal): BigDecimal = x * y
transparent inline def negate(inline x: BigDecimal): BigDecimal = -x

transparent inline def fromInt(x: Int): BigDecimal = BigDecimal(x)
def parseString(str: String): Option[BigDecimal] = Try(BigDecimal(str)).toOption

extension (inline x: BigDecimal)
transparent inline def toInt: Int = x.intValue
transparent inline def toLong: Long = x.longValue
transparent inline def toFloat: Float = x.floatValue
transparent inline def toDouble: Double = x.doubleValue
56 changes: 56 additions & 0 deletions tests/run/inline-numeric/Ordering.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package scala.math
package inline

import java.util.Comparator

trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable:
outer =>

inline def tryCompare(x: T, y: T) = Some(compare(x, y))

def compare(x: T, y: T): Int

override inline def lteq(x: T, y: T): Boolean = compare(x, y) <= 0
override inline def gteq(x: T, y: T): Boolean = compare(x, y) >= 0
override inline def lt(x: T, y: T): Boolean = compare(x, y) < 0
override inline def gt(x: T, y: T): Boolean = compare(x, y) > 0
override inline def equiv(x: T, y: T): Boolean = compare(x, y) == 0

inline def max(x: T, y: T): T = if gteq(x, y) then x else y
inline def min(x: T, y: T): T = if lteq(x, y) then x else y

// This is made into a separate trait, because defining the reverse ordering
// anonymously results in an error:
// Implementation restriction: nested inline methods are not supported
inline def on[U](f: U => T): Ordering[U] = new ReverseOrdering(f) {}

private trait ReverseOrdering[U](f: U => T) extends Ordering[U]:
inline def compare(x: U, y: U) = outer.compare(f(x), f(y))

object Ordering:
trait BigDecimalOrdering extends Ordering[BigDecimal]:
inline def compare(x: BigDecimal, y: BigDecimal) = x.compare(y)

trait BigIntOrdering extends Ordering[BigInt]:
inline def compare(x: BigInt, y: BigInt) = x.compare(y)

trait ByteOrdering extends Ordering[Byte]:
inline def compare(x: Byte, y: Byte) = java.lang.Byte.compare(x, y)

trait CharOrdering extends Ordering[Char]:
inline def compare(x: Char, y: Char) = java.lang.Character.compare(x, y)

trait IntOrdering extends Ordering[Int]:
inline def compare(x: Int, y: Int) = java.lang.Integer.compare(x, y)

trait LongOrdering extends Ordering[Long]:
inline def compare(x: Long, y: Long) = java.lang.Long.compare(x, y)

trait ShortOrdering extends Ordering[Short]:
inline def compare(x: Short, y: Short) = java.lang.Short.compare(x, y)

trait FloatIeeeOrdering extends Ordering[Float]:
inline def compare(x: Float, y: Float) = java.lang.Float.compare(x, y)

trait DoubleIeeeOrdering extends Ordering[Double]:
inline def compare(x: Double, y: Double) = java.lang.Double.compare(x, y)
54 changes: 54 additions & 0 deletions tests/run/inline-numeric/test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import scala.math.inline.*
import scala.math.inline.Ordering.*
import scala.math.inline.Integral.given
import scala.math.inline.Fractional.given

object tests:
inline def foo[T: Numeric](inline a: T, inline b: T) =
a + b * b

inline def div[T: Integral](inline a: T, inline b: T) =
a / b % b

inline def div[T: Fractional](inline a: T, inline b: T) =
a / b + a

inline def toInt[T: Numeric](inline a: T) =
a.toInt

inline def explicitToInt[T](inline a: T)(using n: Numeric[T]) =
n.toInt(a)

inline def sign[T: Numeric](inline a: T) =
a.sign

inline def explicitPlus[T](inline a: T, inline b: T)(using n: Numeric[T]) =
n.plus(a, b)

@main def Test =
def a: Int = 0
def b: Int = 1

val v1 = foo(a, b) // should be a + b * b // can check with -Xprint:inlining
val v2 = foo(a.toShort, b.toShort) // should be a + b * b

val v3 = div(BigDecimal(a), BigDecimal(b))(using BigDecimalAsIfIntegral) // should be BigDecimal(a) quot BigDecimal(b) remainder BigDecimal(b)
val v4 = div(BigDecimal(a), BigDecimal(b))(using BigDecimalIsFractional) // should be BigDecimal(a) / BigDecimal(b) + BigDecimal(a)

val v5 = toInt(a.toFloat) // should be a.toFloat.toInt
val v6 = toInt(a) // should be a

val v7 = sign(a)
val v8 = sign(a.toChar)
val v9 = sign(-7F)

val v10 = sign(BigDecimal(a))(using BigDecimalAsIfIntegral)
val v11 = sign(BigDecimal(a))(using BigDecimalIsFractional) // the condition with isNan() should be removed, i.e. it should be equivalent to v10

val v12 = explicitPlus(3, 5) // should be 8
val v13 = explicitPlus(a, b) // should be a + b

val v14 = explicitToInt(3.2) // should be (3.2).toInt
val v15 = explicitToInt(3) // should be 3
val v16 = explicitToInt(a) // should be a
val v17 = explicitToInt(a.toShort) // should be a.toShort.toInt