|
| 1 | +import scala.quoted._ |
| 2 | + |
| 3 | +trait Ring[T] { |
| 4 | + val zero: T |
| 5 | + val one: T |
| 6 | + val add: (x: T, y: T) => T |
| 7 | + val sub: (x: T, y: T) => T |
| 8 | + val mul: (x: T, y: T) => T |
| 9 | +} |
| 10 | + |
| 11 | +class RingInt extends Ring[Int] { |
| 12 | + val zero = 0 |
| 13 | + val one = 1 |
| 14 | + val add = (x, y) => x + y |
| 15 | + val sub = (x, y) => x - y |
| 16 | + val mul = (x, y) => x * y |
| 17 | +} |
| 18 | + |
| 19 | +class RingIntExpr extends Ring[Expr[Int]] { |
| 20 | + val zero = '(0) |
| 21 | + val one = '(1) |
| 22 | + val add = (x, y) => '(~x + ~y) |
| 23 | + val sub = (x, y) => '(~x - ~y) |
| 24 | + val mul = (x, y) => '(~x * ~y) |
| 25 | +} |
| 26 | + |
| 27 | +class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] { |
| 28 | + val zero = Complex(u.zero, u.zero) |
| 29 | + val one = Complex(u.one, u.zero) |
| 30 | + val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im)) |
| 31 | + val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im)) |
| 32 | + val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re))) |
| 33 | +} |
| 34 | + |
| 35 | +sealed trait PV[T] { |
| 36 | + def expr(implicit l: Liftable[T]): Expr[T] |
| 37 | +} |
| 38 | +case class Sta[T](x: T) extends PV[T] { |
| 39 | + def expr(implicit l: Liftable[T]): Expr[T] = x.toExpr |
| 40 | +} |
| 41 | +case class Dyn[T](x: Expr[T]) extends PV[T] { |
| 42 | + def expr(implicit l: Liftable[T]): Expr[T] = x |
| 43 | +} |
| 44 | + |
| 45 | +class RingPV[U: Liftable](u: Ring[U], eu: Ring[Expr[U]]) extends Ring[PV[U]] { |
| 46 | + val zero: PV[U] = Sta(u.zero) |
| 47 | + val one: PV[U] = Sta(u.one) |
| 48 | + val add = (x: PV[U], y: PV[U]) => (x, y) match { |
| 49 | + case (Sta(u.zero), x) => x |
| 50 | + case (x, Sta(u.zero)) => x |
| 51 | + case (Sta(x), Sta(y)) => Sta(u.add(x, y)) |
| 52 | + case (x, y) => Dyn(eu.add(x.expr, y.expr)) |
| 53 | + } |
| 54 | + val sub = (x: PV[U], y: PV[U]) => (x, y) match { |
| 55 | + case (x, Sta(u.zero)) => x |
| 56 | + case (Sta(x), Sta(y)) => Sta(u.sub(x, y)) |
| 57 | + case (x, y) => Dyn(eu.sub(x.expr, y.expr)) |
| 58 | + } |
| 59 | + val mul = (x: PV[U], y: PV[U]) => (x, y) match { |
| 60 | + case (Sta(u.zero), _) => Sta(u.zero) |
| 61 | + case (_, Sta(u.zero)) => Sta(u.zero) |
| 62 | + case (Sta(u.one), x) => x |
| 63 | + case (x, Sta(u.one)) => x |
| 64 | + case (Sta(x), Sta(y)) => Sta(u.mul(x, y)) |
| 65 | + case (x, y) => Dyn(eu.mul(x.expr, y.expr)) |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | + |
| 70 | +case class Complex[T](re: T, im: T) |
| 71 | + |
| 72 | +object Complex { |
| 73 | + implicit def isLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable[Complex[T]] { |
| 74 | + def toExpr(comp: Complex[T]): Expr[Complex[T]] = '(Complex(~comp.re.toExpr, ~comp.im.toExpr)) |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +case class Vec[Idx, T](size: Idx, get: Idx => T) { |
| 79 | + def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i))) |
| 80 | + def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i))) |
| 81 | +} |
| 82 | + |
| 83 | +object Vec { |
| 84 | + def from[T](elems: T*): Vec[Int, T] = new Vec(elems.size, i => elems(i)) |
| 85 | +} |
| 86 | + |
| 87 | +trait VecOps[Idx, T] { |
| 88 | + val reduce: ((T, T) => T, T, Vec[Idx, T]) => T |
| 89 | +} |
| 90 | + |
| 91 | +class StaticVecOps[T] extends VecOps[Int, T] { |
| 92 | + val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => { |
| 93 | + var sum = zero |
| 94 | + for (i <- 0 until vec.size) |
| 95 | + sum = plus(sum, vec.get(i)) |
| 96 | + sum |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] { |
| 101 | + val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{ |
| 102 | + var sum = ~zero |
| 103 | + var i = 0 |
| 104 | + while (i < ~vec.size) { |
| 105 | + sum = ~{ plus('(sum), vec.get('(i))) } |
| 106 | + i += 1 |
| 107 | + } |
| 108 | + sum |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +class Blas1[Idx, T](r: Ring[T], ops: VecOps[Idx, T]) { |
| 113 | + def dot(v1: Vec[Idx, T], v2: Vec[Idx, T]): T = ops.reduce(r.add, r.zero, v1.zipWith(v2, r.mul)) |
| 114 | +} |
| 115 | + |
| 116 | +object Test { |
| 117 | + |
| 118 | + implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make |
| 119 | + |
| 120 | + def main(args: Array[String]): Unit = { |
| 121 | + val arr1 = Array(0, 1, 2, 4, 8) |
| 122 | + val arr2 = Array(1, 0, 1, 0, 1) |
| 123 | + val cmpxArr1 = Array(Complex(1, 0), Complex(2, 3), Complex(0, 2), Complex(3, 1)) |
| 124 | + val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 1), Complex(2, 0)) |
| 125 | + |
| 126 | + val vec1 = new Vec(arr1.size, i => arr1(i)) |
| 127 | + val vec2 = new Vec(arr2.size, i => arr2(i)) |
| 128 | + val cmpxVec1 = new Vec(cmpxArr1.size, i => cmpxArr1(i)) |
| 129 | + val cmpxVec2 = new Vec(cmpxArr2.size, i => cmpxArr2(i)) |
| 130 | + |
| 131 | + val blasInt = new Blas1(new RingInt, new StaticVecOps) |
| 132 | + val res1 = blasInt.dot(vec1, vec2) |
| 133 | + println(res1) |
| 134 | + println() |
| 135 | + |
| 136 | + val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps) |
| 137 | + val res2 = blasComplexInt.dot( |
| 138 | + cmpxVec1, |
| 139 | + cmpxVec2 |
| 140 | + ) |
| 141 | + println(res2) |
| 142 | + println() |
| 143 | + |
| 144 | + val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps) |
| 145 | + val resCode1 = blasStaticIntExpr.dot( |
| 146 | + vec1.map(_.toExpr), |
| 147 | + vec2.map(_.toExpr) |
| 148 | + ) |
| 149 | + println(resCode1.show) |
| 150 | + println(resCode1.run) |
| 151 | + println() |
| 152 | + |
| 153 | + val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps) |
| 154 | + val resCode2: Expr[(Array[Int], Array[Int]) => Int] = '{ |
| 155 | + (arr1, arr2) => |
| 156 | + if (arr1.length != arr2.length) throw new Exception("...") |
| 157 | + ~{ |
| 158 | + blasExprIntExpr.dot( |
| 159 | + new Vec('(arr1.size), i => '(arr1(~i))), |
| 160 | + new Vec('(arr2.size), i => '(arr2(~i))) |
| 161 | + ) |
| 162 | + } |
| 163 | + } |
| 164 | + println(resCode2.show) |
| 165 | + println(resCode2.run.apply(arr1, arr2)) |
| 166 | + println() |
| 167 | + |
| 168 | + val blasStaticIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) |
| 169 | + val resCode3 = blasStaticIntPVExpr.dot( |
| 170 | + vec1.map(i => Dyn(i.toExpr)), |
| 171 | + vec2.map(i => Sta(i)) |
| 172 | + ).expr |
| 173 | + println(resCode3.show) |
| 174 | + println(resCode3.run) |
| 175 | + println() |
| 176 | + |
| 177 | + val blasExprIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) |
| 178 | + val resCode4: Expr[Array[Int] => Int] = '{ |
| 179 | + arr => |
| 180 | + if (arr.length != ~vec2.size.toExpr) throw new Exception("...") |
| 181 | + ~{ |
| 182 | + blasExprIntPVExpr.dot( |
| 183 | + new Vec(vec2.size, i => Dyn('(arr(~i.toExpr)))), |
| 184 | + vec2.map(i => Sta(i)) |
| 185 | + ).expr |
| 186 | + } |
| 187 | + |
| 188 | + } |
| 189 | + println(resCode4.show) |
| 190 | + println(resCode4.run.apply(arr1)) |
| 191 | + println() |
| 192 | + |
| 193 | + import Complex.isLiftable |
| 194 | + val blasExprComplexPVInt = new Blas1[Int, Complex[PV[Int]]](new RingComplex(new RingPV[Int](new RingInt, new RingIntExpr)), new StaticVecOps) |
| 195 | + val resCode5: Expr[Array[Complex[Int]] => Complex[Int]] = '{ |
| 196 | + arr => |
| 197 | + if (arr.length != ~cmpxVec2.size.toExpr) throw new Exception("...") |
| 198 | + ~{ |
| 199 | + val cpx = blasExprComplexPVInt.dot( |
| 200 | + new Vec(cmpxVec2.size, i => Complex(Dyn('(arr(~i.toExpr).re)), Dyn('(arr(~i.toExpr).im)))), |
| 201 | + new Vec(cmpxVec2.size, i => Complex(Sta(cmpxVec2.get(i).re), Sta(cmpxVec2.get(i).im))) |
| 202 | + ) |
| 203 | + '(Complex(~cpx.re.expr, ~cpx.im.expr)) |
| 204 | + } |
| 205 | + } |
| 206 | + println(resCode5.show) |
| 207 | + println(resCode5.run.apply(cmpxArr1)) |
| 208 | + println() |
| 209 | + |
| 210 | + val RingPVInt = new RingPV[Int](new RingInt, new RingIntExpr) |
| 211 | + // Staged loop of dot product on vectors of Int or Expr[Int] |
| 212 | + val dotIntOptExpr = new Blas1(RingPVInt, new StaticVecOps).dot |
| 213 | + // will generate the code '{ ((arr: scala.Array[scala.Int]) => arr.apply(1).+(arr.apply(3))) } |
| 214 | + val staticVec = Vec[Int, PV[Int]](5, i => Sta((i % 2))) |
| 215 | + val code = '{(arr: Array[Int]) => ~dotIntOptExpr(Vec(5, i => Dyn('(arr(~i.toExpr)))), staticVec).expr } |
| 216 | + println(code.show) |
| 217 | + println() |
| 218 | + } |
| 219 | + |
| 220 | +} |
0 commit comments