diff --git a/library/src/scala/tasty/util/ShowSourceCode.scala b/library/src/scala/tasty/util/ShowSourceCode.scala index c3288ae3e7f0..e8abb6a361a9 100644 --- a/library/src/scala/tasty/util/ShowSourceCode.scala +++ b/library/src/scala/tasty/util/ShowSourceCode.scala @@ -369,6 +369,7 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty expr match { case Term.Lambda(_, _) => // Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)} + assert(stats.size == 1) val DefDef(_, _, args :: Nil, _, Some(rhs)) :: Nil = stats inParens { printArgsDefs(args) diff --git a/tests/run-with-compiler/shonan-hmm-simple.check b/tests/run-with-compiler/shonan-hmm-simple.check new file mode 100644 index 000000000000..420f66c28728 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm-simple.check @@ -0,0 +1,36 @@ +10 + +Complex(4,3) + +0.+(0.*(1)).+(1.*(0)).+(2.*(1)).+(4.*(0)).+(8.*(1)) +10 + +((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => { + if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else () + var sum: scala.Int = 0 + var i: scala.Int = 0 + while (i.<(scala.Predef.intArrayOps(arr1).size)) { + sum = sum.+(arr1.apply(i).*(arr2.apply(i))) + i = i.+(1) + } + (sum: scala.Int) +}) +10 + +0.+(2).+(8) +10 + +((arr: scala.Array[scala.Int]) => { + if (arr.length.!=(5)) throw new scala.Exception("...") else () + arr.apply(0).+(arr.apply(2)).+(arr.apply(4)) +}) +10 + +((arr: scala.Array[Complex[scala.Int]]) => { + if (arr.length.!=(4)) throw new scala.Exception("...") else () + Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2))) +}) +Complex(4,3) + +((arr: scala.Array[scala.Int]) => arr.apply(1).+(arr.apply(3))) + diff --git a/tests/run-with-compiler/shonan-hmm-simple.scala b/tests/run-with-compiler/shonan-hmm-simple.scala new file mode 100644 index 000000000000..2beb204e1694 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm-simple.scala @@ -0,0 +1,220 @@ +import scala.quoted._ + +trait Ring[T] { + val zero: T + val one: T + val add: (x: T, y: T) => T + val sub: (x: T, y: T) => T + val mul: (x: T, y: T) => T +} + +class RingInt extends Ring[Int] { + val zero = 0 + val one = 1 + val add = (x, y) => x + y + val sub = (x, y) => x - y + val mul = (x, y) => x * y +} + +class RingIntExpr extends Ring[Expr[Int]] { + val zero = '(0) + val one = '(1) + val add = (x, y) => '(~x + ~y) + val sub = (x, y) => '(~x - ~y) + val mul = (x, y) => '(~x * ~y) +} + +class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] { + val zero = Complex(u.zero, u.zero) + val one = Complex(u.one, u.zero) + val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im)) + val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im)) + val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re))) +} + +sealed trait PV[T] { + def expr(implicit l: Liftable[T]): Expr[T] +} +case class Sta[T](x: T) extends PV[T] { + def expr(implicit l: Liftable[T]): Expr[T] = x.toExpr +} +case class Dyn[T](x: Expr[T]) extends PV[T] { + def expr(implicit l: Liftable[T]): Expr[T] = x +} + +class RingPV[U: Liftable](u: Ring[U], eu: Ring[Expr[U]]) extends Ring[PV[U]] { + val zero: PV[U] = Sta(u.zero) + val one: PV[U] = Sta(u.one) + val add = (x: PV[U], y: PV[U]) => (x, y) match { + case (Sta(u.zero), x) => x + case (x, Sta(u.zero)) => x + case (Sta(x), Sta(y)) => Sta(u.add(x, y)) + case (x, y) => Dyn(eu.add(x.expr, y.expr)) + } + val sub = (x: PV[U], y: PV[U]) => (x, y) match { + case (x, Sta(u.zero)) => x + case (Sta(x), Sta(y)) => Sta(u.sub(x, y)) + case (x, y) => Dyn(eu.sub(x.expr, y.expr)) + } + val mul = (x: PV[U], y: PV[U]) => (x, y) match { + case (Sta(u.zero), _) => Sta(u.zero) + case (_, Sta(u.zero)) => Sta(u.zero) + case (Sta(u.one), x) => x + case (x, Sta(u.one)) => x + case (Sta(x), Sta(y)) => Sta(u.mul(x, y)) + case (x, y) => Dyn(eu.mul(x.expr, y.expr)) + } +} + + +case class Complex[T](re: T, im: T) + +object Complex { + implicit def isLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable[Complex[T]] { + def toExpr(comp: Complex[T]): Expr[Complex[T]] = '(Complex(~comp.re.toExpr, ~comp.im.toExpr)) + } +} + +case class Vec[Idx, T](size: Idx, get: Idx => T) { + def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i))) + def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i))) +} + +object Vec { + def from[T](elems: T*): Vec[Int, T] = new Vec(elems.size, i => elems(i)) +} + +trait VecOps[Idx, T] { + val reduce: ((T, T) => T, T, Vec[Idx, T]) => T +} + +class StaticVecOps[T] extends VecOps[Int, T] { + val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => { + var sum = zero + for (i <- 0 until vec.size) + sum = plus(sum, vec.get(i)) + sum + } +} + +class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] { + val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{ + var sum = ~zero + var i = 0 + while (i < ~vec.size) { + sum = ~{ plus('(sum), vec.get('(i))) } + i += 1 + } + sum + } +} + +class Blas1[Idx, T](r: Ring[T], ops: VecOps[Idx, T]) { + def dot(v1: Vec[Idx, T], v2: Vec[Idx, T]): T = ops.reduce(r.add, r.zero, v1.zipWith(v2, r.mul)) +} + +object Test { + + implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make + + def main(args: Array[String]): Unit = { + val arr1 = Array(0, 1, 2, 4, 8) + val arr2 = Array(1, 0, 1, 0, 1) + val cmpxArr1 = Array(Complex(1, 0), Complex(2, 3), Complex(0, 2), Complex(3, 1)) + val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 1), Complex(2, 0)) + + val vec1 = new Vec(arr1.size, i => arr1(i)) + val vec2 = new Vec(arr2.size, i => arr2(i)) + val cmpxVec1 = new Vec(cmpxArr1.size, i => cmpxArr1(i)) + val cmpxVec2 = new Vec(cmpxArr2.size, i => cmpxArr2(i)) + + val blasInt = new Blas1(new RingInt, new StaticVecOps) + val res1 = blasInt.dot(vec1, vec2) + println(res1) + println() + + val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps) + val res2 = blasComplexInt.dot( + cmpxVec1, + cmpxVec2 + ) + println(res2) + println() + + val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps) + val resCode1 = blasStaticIntExpr.dot( + vec1.map(_.toExpr), + vec2.map(_.toExpr) + ) + println(resCode1.show) + println(resCode1.run) + println() + + val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps) + val resCode2: Expr[(Array[Int], Array[Int]) => Int] = '{ + (arr1, arr2) => + if (arr1.length != arr2.length) throw new Exception("...") + ~{ + blasExprIntExpr.dot( + new Vec('(arr1.size), i => '(arr1(~i))), + new Vec('(arr2.size), i => '(arr2(~i))) + ) + } + } + println(resCode2.show) + println(resCode2.run.apply(arr1, arr2)) + println() + + val blasStaticIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) + val resCode3 = blasStaticIntPVExpr.dot( + vec1.map(i => Dyn(i.toExpr)), + vec2.map(i => Sta(i)) + ).expr + println(resCode3.show) + println(resCode3.run) + println() + + val blasExprIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) + val resCode4: Expr[Array[Int] => Int] = '{ + arr => + if (arr.length != ~vec2.size.toExpr) throw new Exception("...") + ~{ + blasExprIntPVExpr.dot( + new Vec(vec2.size, i => Dyn('(arr(~i.toExpr)))), + vec2.map(i => Sta(i)) + ).expr + } + + } + println(resCode4.show) + println(resCode4.run.apply(arr1)) + println() + + import Complex.isLiftable + val blasExprComplexPVInt = new Blas1[Int, Complex[PV[Int]]](new RingComplex(new RingPV[Int](new RingInt, new RingIntExpr)), new StaticVecOps) + val resCode5: Expr[Array[Complex[Int]] => Complex[Int]] = '{ + arr => + if (arr.length != ~cmpxVec2.size.toExpr) throw new Exception("...") + ~{ + val cpx = blasExprComplexPVInt.dot( + new Vec(cmpxVec2.size, i => Complex(Dyn('(arr(~i.toExpr).re)), Dyn('(arr(~i.toExpr).im)))), + new Vec(cmpxVec2.size, i => Complex(Sta(cmpxVec2.get(i).re), Sta(cmpxVec2.get(i).im))) + ) + '(Complex(~cpx.re.expr, ~cpx.im.expr)) + } + } + println(resCode5.show) + println(resCode5.run.apply(cmpxArr1)) + println() + + val RingPVInt = new RingPV[Int](new RingInt, new RingIntExpr) + // Staged loop of dot product on vectors of Int or Expr[Int] + val dotIntOptExpr = new Blas1(RingPVInt, new StaticVecOps).dot + // will generate the code '{ ((arr: scala.Array[scala.Int]) => arr.apply(1).+(arr.apply(3))) } + val staticVec = Vec[Int, PV[Int]](5, i => Sta((i % 2))) + val code = '{(arr: Array[Int]) => ~dotIntOptExpr(Vec(5, i => Dyn('(arr(~i.toExpr)))), staticVec).expr } + println(code.show) + println() + } + +} diff --git a/tests/run-with-compiler/shonan-hmm.check b/tests/run-with-compiler/shonan-hmm.check new file mode 100644 index 000000000000..c4f1579e7497 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm.check @@ -0,0 +1,308 @@ +Complex(0,10) +Complex(1.*(4).-(2.*(2)), 1.*(2).+(2.*(4))) +List(Complex(2,0), Complex(-4,4), Complex(-2,6)) +((vout: scala.Array[Complex[scala.Int]], v1: scala.Array[Complex[scala.Int]], v2: scala.Array[Complex[scala.Int]]) => { + val n: scala.Int = vout.length + var i: scala.Int = 0 + while (i.<(n)) { + vout.update(i, Complex.apply[scala.Int](v1.apply(i).re.*(v2.apply(i).re).-(v1.apply(i).im.*(v2.apply(i).im)), v1.apply(i).re.*(v2.apply(i).im).+(v1.apply(i).im.*(v2.apply(i).re)))) + i = i.+(1) + } +}) +List(25, 30, 20, 43, 44) + + + +((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => { + val n: scala.Int = vout.length + val m: scala.Int = v.length + var i: scala.Int = 0 + while (i.<(n)) { + vout.update(i, { + var sum: scala.Int = 0 + var i$2: scala.Int = 0 + while (i$2.<(m)) { + sum = sum.+(v.apply(i$2).*(a.apply(i).apply(i$2))) + i$2 = i$2.+(1) + } + (sum: scala.Int) + }) + i = i.+(1) + } +}) + + + +((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => { + if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else () + if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else () + vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1)))) + vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1)))) + vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1)))) +}) + + + +{ + val arr: scala.Array[scala.Array[scala.Int]] = { + val array: scala.Array[scala.Array[scala.Int]] = dotty.runtime.Arrays.newGenericArray[scala.Array[scala.Int]](5)({ + scala.reflect.ClassTag.apply[scala.Array[scala.Int]](scala.Predef.classOf[scala.Array[scala.Int]]) + }) + array.update(0, { + val array$2: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$2.update(0, 5) + array$2.update(1, 0) + array$2.update(2, 0) + array$2.update(3, 5) + array$2.update(4, 0) + array$2 + }) + array.update(1, { + val array$3: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$3.update(0, 0) + array$3.update(1, 0) + array$3.update(2, 10) + array$3.update(3, 0) + array$3.update(4, 0) + array$3 + }) + array.update(2, { + val array$4: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$4.update(0, 0) + array$4.update(1, 10) + array$4.update(2, 0) + array$4.update(3, 0) + array$4.update(4, 0) + array$4 + }) + array.update(3, { + val array$5: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$5.update(0, 0) + array$5.update(1, 0) + array$5.update(2, 2) + array$5.update(3, 3) + array$5.update(4, 5) + array$5 + }) + array.update(4, { + val array$6: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$6.update(0, 0) + array$6.update(1, 0) + array$6.update(2, 3) + array$6.update(3, 0) + array$6.update(4, 7) + array$6 + }) + array + } + + ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { + if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0))) + vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0))) + vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0))) + vout.update(3, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(2)).+(v.apply(3).*(3)).+(v.apply(4).*(5))) + vout.update(4, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(3)).+(v.apply(3).*(0)).+(v.apply(4).*(7))) + }) +} + + + +{ + val arr: scala.Array[scala.Array[scala.Int]] = { + val array: scala.Array[scala.Array[scala.Int]] = dotty.runtime.Arrays.newGenericArray[scala.Array[scala.Int]](5)({ + scala.reflect.ClassTag.apply[scala.Array[scala.Int]](scala.Predef.classOf[scala.Array[scala.Int]]) + }) + array.update(0, { + val array$2: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$2.update(0, 5) + array$2.update(1, 0) + array$2.update(2, 0) + array$2.update(3, 5) + array$2.update(4, 0) + array$2 + }) + array.update(1, { + val array$3: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$3.update(0, 0) + array$3.update(1, 0) + array$3.update(2, 10) + array$3.update(3, 0) + array$3.update(4, 0) + array$3 + }) + array.update(2, { + val array$4: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$4.update(0, 0) + array$4.update(1, 10) + array$4.update(2, 0) + array$4.update(3, 0) + array$4.update(4, 0) + array$4 + }) + array.update(3, { + val array$5: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$5.update(0, 0) + array$5.update(1, 0) + array$5.update(2, 2) + array$5.update(3, 3) + array$5.update(4, 5) + array$5 + }) + array.update(4, { + val array$6: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$6.update(0, 0) + array$6.update(1, 0) + array$6.update(2, 3) + array$6.update(3, 0) + array$6.update(4, 7) + array$6 + }) + array + } + + ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { + if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) + vout.update(1, v.apply(2).*(10)) + vout.update(2, v.apply(1).*(10)) + vout.update(3, v.apply(2).*(2).+(v.apply(3).*(3)).+(v.apply(4).*(5))) + vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7))) + }) +} + + + +{ + val arr: scala.Array[scala.Array[scala.Int]] = { + val array: scala.Array[scala.Array[scala.Int]] = dotty.runtime.Arrays.newGenericArray[scala.Array[scala.Int]](5)({ + scala.reflect.ClassTag.apply[scala.Array[scala.Int]](scala.Predef.classOf[scala.Array[scala.Int]]) + }) + array.update(0, { + val array$2: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$2.update(0, 5) + array$2.update(1, 0) + array$2.update(2, 0) + array$2.update(3, 5) + array$2.update(4, 0) + array$2 + }) + array.update(1, { + val array$3: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$3.update(0, 0) + array$3.update(1, 0) + array$3.update(2, 10) + array$3.update(3, 0) + array$3.update(4, 0) + array$3 + }) + array.update(2, { + val array$4: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$4.update(0, 0) + array$4.update(1, 10) + array$4.update(2, 0) + array$4.update(3, 0) + array$4.update(4, 0) + array$4 + }) + array.update(3, { + val array$5: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$5.update(0, 0) + array$5.update(1, 0) + array$5.update(2, 2) + array$5.update(3, 3) + array$5.update(4, 5) + array$5 + }) + array.update(4, { + val array$6: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array$6.update(0, 0) + array$6.update(1, 0) + array$6.update(2, 3) + array$6.update(3, 0) + array$6.update(4, 7) + array$6 + }) + array + } + + ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { + if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) + vout.update(1, v.apply(2).*(10)) + vout.update(2, v.apply(1).*(10)) + vout.update(3, { + var sum: scala.Int = 0 + var i: scala.Int = 0 + while (i.<(5)) { + sum = sum.+(v.apply(i).*(arr.apply(3).apply(i))) + i = i.+(1) + } + (sum: scala.Int) + }) + vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7))) + }) +} + + + +((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { + if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) + vout.update(1, v.apply(2).*(10)) + vout.update(2, v.apply(1).*(10)) + vout.update(3, { + var sum: scala.Int = 0 + var i: scala.Int = 0 + while (i.<(5)) { + sum = sum.+(v.apply(i).*({ + val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array.update(0, 0) + array.update(1, 0) + array.update(2, 2) + array.update(3, 3) + array.update(4, 5) + array + }.apply(i))) + i = i.+(1) + } + (sum: scala.Int) + }) + vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7))) +}) + + + +{ + val row: scala.Array[scala.Int] = { + val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5) + array.update(0, 0) + array.update(1, 0) + array.update(2, 2) + array.update(3, 3) + array.update(4, 5) + array + } + + ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { + if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) + vout.update(1, v.apply(2).*(10)) + vout.update(2, v.apply(1).*(10)) + vout.update(3, { + var sum: scala.Int = 0 + var i: scala.Int = 0 + while (i.<(5)) { + sum = sum.+(v.apply(i).*(row.apply(i))) + i = i.+(1) + } + (sum: scala.Int) + }) + vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7))) + }) +} diff --git a/tests/run-with-compiler/shonan-hmm/Blas.scala b/tests/run-with-compiler/shonan-hmm/Blas.scala new file mode 100644 index 000000000000..4c5f873c3888 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Blas.scala @@ -0,0 +1,30 @@ + +import scala.quoted._ + +class Blas1[Idx, T, Unt](tring: Ring[T], vec: VecOp[Idx, Unt]) { + import tring._ + import vec._ + + implicit class Blas1VecOps(v1: Vec[Idx, T]) { + def `*.`(v2: Vec[Idx, T]): Vec[Idx, T] = v1.zipWith(v2, mul) + } + + implicit class Blas1OVecOps(vout: OVec[Idx, T, Unt]) { + def :=(vin: Vec[Idx, T]): Unt = iter(vout.vecAssign(vin)) + } + override def toString(): String = s"Blas1($tring, $vec)" +} + +class Blas2[Idx, T, Unt](tring: Ring[T], vec: VecROp[Idx, T, Unt]) extends Blas1[Idx, T, Unt](tring, vec) { + import tring._ + import vec._ + + implicit class Blas2VecOps(v1: Vec[Idx, T]) { + def dot(v2: Vec[Idx, T]): T = reduce(add, zero, v1 `*.` v2) + } + + implicit class Blas2MatOps(a: Vec[Idx, Vec[Idx, T]]) { + def *(v: Vec[Idx, T]): Vec[Idx, T] = a.vecMap(x => v dot x) + } + override def toString(): String = s"Blas2($tring, $vec)" +} diff --git a/tests/run-with-compiler/shonan-hmm/Complex.scala b/tests/run-with-compiler/shonan-hmm/Complex.scala new file mode 100644 index 000000000000..da3490b6f228 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Complex.scala @@ -0,0 +1,15 @@ + +import scala.quoted._ + +case class Complex[T](re: T, im: T) + +object Complex { + implicit def complexIsLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable { + def toExpr(c: Complex[T]): Expr[Complex[T]] = '{ Complex(~c.re.toExpr, ~c.im.toExpr) } + } + + def of_complex_expr(x: Expr[Complex[Int]]): Complex[Expr[Int]] = Complex('((~x).re), '((~x).im)) + def of_expr_complex(x: Complex[Expr[Int]]): Expr[Complex[Int]] = '(Complex(~x.re, ~x.im)) + + +} \ No newline at end of file diff --git a/tests/run-with-compiler/shonan-hmm/Lifters.scala b/tests/run-with-compiler/shonan-hmm/Lifters.scala new file mode 100644 index 000000000000..25b79d2428a4 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Lifters.scala @@ -0,0 +1,30 @@ + +import UnrolledExpr._ + +import scala.reflect.ClassTag +import scala.quoted._ + +object Lifters { + + implicit def ClassTagIsLiftable[T : Type](implicit ct: ClassTag[T]): Liftable[ClassTag[T]] = + ct => '(ClassTag(~ct.runtimeClass.toExpr)) + + implicit def ArrayIsLiftable[T : Type: ClassTag](implicit l: Liftable[T]): Liftable[Array[T]] = arr => '{ + val array = new Array[T](~arr.length.toExpr)(~implicitly[ClassTag[T]].toExpr) + ~initArray(arr, '(array)) + } + + implicit def IntArrayIsLiftable: Liftable[Array[Int]] = arr => '{ + val array = new Array[Int](~arr.length.toExpr) + ~initArray(arr, '(array)) + } + + private def initArray[T : Liftable](arr: Array[T], array: Expr[Array[T]]): Expr[Array[T]] = { + UnrolledExpr.block( + arr.zipWithIndex.map { + case (x, i) => '{ (~array)(~i.toExpr) = ~x.toExpr } + }.toList, + array) + } + +} diff --git a/tests/run-with-compiler/shonan-hmm/MVmult.scala b/tests/run-with-compiler/shonan-hmm/MVmult.scala new file mode 100644 index 000000000000..303e4f71b2bf --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/MVmult.scala @@ -0,0 +1,178 @@ + +import dotty.tools.dotc.quoted.Toolbox._ +import scala.quoted._ + +class MVmult[Idx, T, Unt](tring: Ring[T], vec: VecROp[Idx, T, Unt]) { + private[this] val blas2 = new Blas2(tring, vec) + import blas2._ + def mvmult(vout: OVec[Idx, T, Unt], a: Vec[Idx, Vec[Idx, T]], v: Vec[Idx, T]): Unt = vout := a * v + override def toString(): String = s"MVmult($tring, $vec)" +} + +object MVmult { + def mvmult_p(vout: Array[Int], a: Array[Array[Int]], v: Array[Int]): Unit = { + val n = vout.length + val m = v.length + + val vout_ = OVec(n, (i, x: Int) => vout(i) = x) + val a_ = Vec (n, i => Vec(m, j => a(i)(j))) + val v_ = Vec (n, i => v(i)) + + val MV = new MVmult[Int, Int, Unit](RingInt, new StaticVecR(RingInt)) + MV.mvmult(vout_, a_, v_) + } + + def mvmult_c: Expr[(Array[Int], Array[Array[Int]], Array[Int]) => Unit] = '{ + (vout, a, v) => { + val n = vout.length + val m = v.length + ~{ + val vout_ = OVec('(n), (i, x: Expr[Int]) => '(vout(~i) = ~x)) + val a_ = Vec('(n), (i: Expr[Int]) => Vec('(m), (j: Expr[Int]) => '{ a(~i)(~j) } )) + val v_ = Vec('(m), (i: Expr[Int]) => '(v(~i))) + + val MV = new MVmult[Expr[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRDyn) + MV.mvmult(vout_, a_, v_) + } + } + } + + def mvmult_mc(n: Int, m: Int): Expr[(Array[Int], Array[Array[Int]], Array[Int]) => Unit] = { + val MV = new MVmult[Int, Expr[Int], Expr[Unit]](RingIntExpr, new VecRStaDim(RingIntExpr)) + '{ + (vout, a, v) => { + if (~n.toExpr != vout.length) throw new IndexOutOfBoundsException(~n.toString.toExpr) + if (~m.toExpr != v.length) throw new IndexOutOfBoundsException(~m.toString.toExpr) + ~{ + val vout_ = OVec(n, (i, x: Expr[Int]) => '(vout(~i.toExpr) = ~x)) + val a_ = Vec(n, i => Vec(m, j => '{ a(~i.toExpr)(~j.toExpr) } )) + val v_ = Vec(m, i => '(v(~i.toExpr))) + + MV.mvmult(vout_, a_, v_) + } + } + } + } + + def mvmult_ac(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = { + import Lifters._ + '{ + val arr = ~a.toExpr + ~{ + val (n, m, a2) = amat1(a, '(arr)) + mvmult_abs0(new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))(n, m, a2) + } + } + } + + def mvmult_opt(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = { + import Lifters._ + '{ + val arr = ~a.toExpr + ~{ + val (n, m, a2) = amat1(a, '(arr)) + mvmult_abs0(new RingIntOPExpr, new VecRStaDyn(new RingIntPExpr))(n, m, a2) + } + } + } + + def mvmult_roll(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = { + import Lifters._ + '{ + val arr = ~a.toExpr + ~{ + val (n, m, a2) = amat1(a, '(arr)) + mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2) + } + } + } + + def mvmult_let1(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = { + val (n, m, a2) = amatCopy(a, copy_row1) + mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2) + } + + def mvmult_let(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = { + initRows(a) { rows => + val (n, m, a2) = amat2(a, rows) + mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2) + } + } + + def initRows[T](a: Array[Array[Int]])(cont: Array[Expr[Array[Int]]] => Expr[T]): Expr[T] = { + import Lifters._ + def loop(i: Int, acc: List[Expr[Array[Int]]]): Expr[T] = { + if (i >= a.length) cont(acc.toArray.reverse) + else if (a(i).count(_ != 0) < VecRStaOptDynInt.threshold) { + val default: Expr[Array[Int]] = '(null.asInstanceOf[Array[Int]]) // never accessed + loop(i + 1, default :: acc) + } else '{ + val row = ~a(i).toExpr + ~{ loop(i + 1, '(row) :: acc) } + } + } + loop(0, Nil) + } + + def amat1(a: Array[Array[Int]], aa: Expr[Array[Array[Int]]]): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = { + val n = a.length + val m = a(0).length + val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match { + case (Sta(i), Sta(j)) => Sta(a(i)(j)) + case (Sta(i), Dyn(j)) => Dyn('((~aa)(~i.toExpr)(~j))) + case (i, j) => Dyn('{ (~aa)(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) }) + })) + (n, m, vec) + } + + def amat2(a: Array[Array[Int]], refs: Array[Expr[Array[Int]]]): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = { + val n = a.length + val m = a(0).length + val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match { + case (Sta(i), Sta(j)) => Sta(a(i)(j)) + case (Sta(i), Dyn(j)) => Dyn('((~refs(i))(~j))) + })) + (n, m, vec) + } + + def amatCopy(a: Array[Array[Int]], copyRow: Array[Int] => (Expr[Int] => Expr[Int])): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = { + val n = a.length + val m = a(0).length + val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match { + case (Sta(i), Sta(j)) => Sta(a(i)(j)) + case (Sta(i), Dyn(j)) => + val defrec = copyRow(a(i)) + Dyn(defrec(j)) + case (i, j) => ??? + })) + (n, m, vec) + } + + def copy_row1: Array[Int] => (Expr[Int] => Expr[Int]) = v => { + import Lifters._ + val arr = v.toExpr + i => '{ (~arr).apply(~i) } + } + + def copy_row_let: Array[Int] => (Expr[Int] => Expr[Int]) = v => { + import Lifters._ + val arr: Expr[Array[Int]] = ??? // FIXME used genlet v.toExpr + i => '{ (~arr).apply(~i) } + } + + private def mvmult_abs0(ring: Ring[PV[Int]], vecOp: VecROp[PV[Int], PV[Int], Expr[Unit]])(n: Int, m: Int, a: Vec[PV[Int], Vec[PV[Int], PV[Int]]]): Expr[(Array[Int], Array[Int]) => Unit] = { + '{ + (vout, v) => { + if (~n.toExpr != vout.length) throw new IndexOutOfBoundsException(~n.toString.toExpr) + if (~m.toExpr != v.length) throw new IndexOutOfBoundsException(~m.toString.toExpr) + ~{ + val vout_ : OVec[PV[Int], PV[Int], Expr[Unit]] = OVec(Sta(n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x))) + val v_ : Vec[PV[Int], PV[Int]] = Vec(Sta(m), i => Dyn('(v(~Dyns.dyni(i))))) + val MV = new MVmult[PV[Int], PV[Int], Expr[Unit]](ring, vecOp) + MV.mvmult(vout_, a, v_) + } + } + } + } + +} diff --git a/tests/run-with-compiler/shonan-hmm/PV.scala b/tests/run-with-compiler/shonan-hmm/PV.scala new file mode 100644 index 000000000000..d33cb7a45333 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/PV.scala @@ -0,0 +1,16 @@ + +import scala.quoted._ + +sealed trait PV[T] + +case class Sta[T](x: T) extends PV[T] + +case class Dyn[T](x: Expr[T]) extends PV[T] + +object Dyns { + def dyn[T: Liftable](pv: PV[T]): Expr[T] = pv match { + case Sta(x) => x.toExpr + case Dyn(x) => x + } + val dyni: PV[Int] => Expr[Int] = dyn[Int] +} diff --git a/tests/run-with-compiler/shonan-hmm/Ring.scala b/tests/run-with-compiler/shonan-hmm/Ring.scala new file mode 100644 index 000000000000..521fff3ac819 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Ring.scala @@ -0,0 +1,89 @@ + +import scala.quoted._ + +trait Ring[T] { + def zero: T + val one: T + def add: (x: T, y: T) => T + def sub: (x: T, y: T) => T + def mul: (x: T, y: T) => T + + implicit class Ops(x: T) { + def +(y: T): T = add(x, y) + def -(y: T): T = sub(x, y) + def *(y: T): T = mul(x, y) + } +} + +object RingInt extends Ring[Int] { + val zero = 0 + val one = 0 + val add = (x, y) => x + y + val sub = (x, y) => x - y + val mul = (x, y) => x * y + override def toString(): String = "RingInt" +} + +object RingIntExpr extends Ring[Expr[Int]] { + val zero = '(0) + val one = '(1) + val add = (x, y) => '(~x + ~y) + val sub = (x, y) => '(~x - ~y) + val mul = (x, y) => '(~x * ~y) + override def toString(): String = "RingIntExpr" +} + +case class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] { + import u._ + val zero = Complex(u.zero, u.zero) + val one = Complex(u.one, u.zero) + val add = (x, y) => Complex(x.re + y.re, x.im + y.im) + val sub = (x, y) => Complex(x.re + y.re, x.im + y.im) + val mul = (x, y) => Complex(x.re * y.re - x.im * y.im, x.re * y.im + x.im * y.re) + override def toString(): String = s"RingComplex($u)" +} + +case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends Ring[PV[U]] { + type T = PV[U] + + val dyn = Dyns.dyn[U] + import staRing._ + import dynRing._ + + val zero: T = Sta(staRing.zero) + val one: T = Sta(staRing.one) + def add = (x: T, y: T) => (x, y) match { + case (Sta(x), Sta(y)) => Sta(x + y) + case (x, y) => Dyn(dyn(x) + dyn(y)) + } + def sub = (x: T, y: T) => (x, y) match { + case (Sta(x), Sta(y)) => Sta(x - y) + case (x, y) => Dyn(dyn(x) - dyn(y)) + } + def mul = (x: T, y: T) => (x, y) match { + case (Sta(x), Sta(y)) => Sta(x * y) + case (x, y) => Dyn(dyn(x) * dyn(y)) + } +} + +class RingIntPExpr extends RingPV(RingInt, RingIntExpr) + +class RingIntOPExpr extends RingIntPExpr { + override def add = (x: PV[Int], y: PV[Int]) => (x, y) match { + case (Sta(0), y) => y + case (x, Sta(0)) => x + case (x, y) => super.add(x, y) + } + override def sub = (x: T, y: T) => (x, y) match { + case (Sta(0), y) => y + case (x, Sta(0)) => x + case (x, y) => super.sub(x, y) + } + override def mul = (x: T, y: T) => (x, y) match { + case (Sta(0), y) => Sta(0) + case (x, Sta(0)) => Sta(0) + case (Sta(1), y) => y + case (x, Sta(1)) => x + case (x, y) => super.mul(x, y) + } +} diff --git a/tests/run-with-compiler/shonan-hmm/Test.scala b/tests/run-with-compiler/shonan-hmm/Test.scala new file mode 100644 index 000000000000..947836669602 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Test.scala @@ -0,0 +1,92 @@ + +import dotty.tools.dotc.quoted.Toolbox._ +import scala.quoted._ + +// DYNAMIC + +object Test { + + def main(args: Array[String]): Unit = { + { + val intComplex = new RingComplex(RingInt) + import intComplex._ + + println(Complex(1, 2) * Complex(4, 2)) + } + + { + val intExprComplex = new RingComplex(RingIntExpr) + import intExprComplex._ + + val res = Complex('(1), '(2)) * Complex('(4), '(2)) + println(s"Complex(${res.re.show}, ${res.im.show})") + } + + // { + // val intExprComplex = implicitly[Ring[Expr[Complex[Int]]]] + // import intExprComplex._ + + // val res = '(Complex(1, 2)) * '(Complex(4, 2)) + // println(res.show) + // } + + val arr1 = Array(Complex(1, 0), Complex(0, 4), Complex(2, 2)) + val arr2 = Array(Complex(2, 0), Complex(1, 1), Complex(1, 2)) + val out = Array(Complex(0, 0), Complex(0, 0), Complex(0, 0)) + Vmults.vmult(out, arr1, arr2) + println(out.toList) + + println(Vmults.vmultCA.show) + + val a = Array( + Array( 5, 0, 0, 5, 0), + Array( 0, 0, 10, 0, 0), + Array( 0, 10, 0, 0, 0), + Array( 0, 0, 2, 3, 5), + Array( 0, 0, 3, 0, 7) + ) + + val v1 = Array(1, 2, 3, 4, 5) + val v1out = Array(0, 0, 0, 0, 0) + MVmult.mvmult_p(v1out, a, v1) + println(v1out.toList) + println() + println() + println() + + println(MVmult.mvmult_c.show) + println() + println() + println() + + println(MVmult.mvmult_mc(3, 2).show) + println() + println() + println() + + println(MVmult.mvmult_ac(a).show) + println() + println() + println() + + println(MVmult.mvmult_opt(a).show) + println() + println() + println() + + println(MVmult.mvmult_roll(a).show) + println() + println() + println() + + println(MVmult.mvmult_let1(a).show) + println() + println() + println() + + println(MVmult.mvmult_let(a).show) + } +} + + + diff --git a/tests/run-with-compiler/shonan-hmm/UnrolledExpr.scala b/tests/run-with-compiler/shonan-hmm/UnrolledExpr.scala new file mode 100644 index 000000000000..72877d2ad483 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/UnrolledExpr.scala @@ -0,0 +1,30 @@ +import scala.quoted._ +import Lifters._ + +object UnrolledExpr { + + implicit class Unrolled[T: Liftable, It <: Iterable[T]](xs: It) { + def unrolled: UnrolledExpr[T, It] = new UnrolledExpr(xs) + } + + // TODO support blocks in the compiler to avoid creating trees of blocks? + def block[T](stats: Iterable[Expr[_]], expr: Expr[T]): Expr[T] = { + def rec(stats: List[Expr[_]]): Expr[T] = stats match { + case x :: xs => '{ ~x; ~rec(xs) } + case Nil => expr + } + rec(stats.toList) + } + +} + +class UnrolledExpr[T: Liftable, It <: Iterable[T]](xs: It) { + import UnrolledExpr._ + + def foreach[U](f: T => Expr[U]): Expr[Unit] = block(xs.map(f), '()) + + def withFilter(f: T => Boolean): UnrolledExpr[T, Iterable[T]] = new UnrolledExpr(xs.filter(f)) + + def foldLeft[U](acc: Expr[U])(f: (Expr[U], T) => Expr[U]): Expr[U] = + xs.foldLeft(acc)((acc, x) => f(acc, x)) +} diff --git a/tests/run-with-compiler/shonan-hmm/Vec.scala b/tests/run-with-compiler/shonan-hmm/Vec.scala new file mode 100644 index 000000000000..68ad9f43f8b6 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Vec.scala @@ -0,0 +1,20 @@ +import scala.quoted._ + +case class Vec[Idx, T](size: Idx, get: Idx => T) { + def apply(idx: Idx): T = get(idx) + + def vecMap[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i))) + + def zipWith[U, V](vec2: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = + Vec(size, i => f(get(i), vec2(i))) +} + +case class OVec[Idx, T, Unt](size: Idx, update: (Idx, T) => Unt) { + def vecAssign(vecIn: Vec[Idx, T]): Vec[Idx, Unt] = + Vec(vecIn.size, i => update(i, vecIn(i))) +} + +object Vec { + def fromArray[T](a: Array[T]): (Vec[Int, T], OVec[Int, T, Unit]) = + (Vec(a.size, i => a(i)), OVec(a.size, (i, v) => a(i) = v)) +} diff --git a/tests/run-with-compiler/shonan-hmm/VecOp.scala b/tests/run-with-compiler/shonan-hmm/VecOp.scala new file mode 100644 index 000000000000..c8bbbd6d3b69 --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/VecOp.scala @@ -0,0 +1,24 @@ +import scala.quoted._ + +trait VecOp[Idx, Unt] { + def iter: Vec[Idx, Unt] => Unt +} + +class VecSta extends VecOp[Int, Unit] { + def iter: Vec[Int, Unit] => Unit = { arr => + for (i <- 0 until arr.size) + arr(i) + } + override def toString(): String = s"StaticVec" +} + +class VecDyn extends VecOp[Expr[Int], Expr[Unit]] { + def iter: Vec[Expr[Int], Expr[Unit]] => Expr[Unit] = arr => '{ + var i = 0 + while (i < ~arr.size) { + ~arr('(i)) + i += 1 + } + } + override def toString(): String = s"DynVec" +} diff --git a/tests/run-with-compiler/shonan-hmm/VecROp.scala b/tests/run-with-compiler/shonan-hmm/VecROp.scala new file mode 100644 index 000000000000..3d8b8aa63ebe --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/VecROp.scala @@ -0,0 +1,93 @@ + +import scala.quoted._ + +trait VecROp[Idx, T, Unt] extends VecOp[Idx, Unt] { + def reduce: ((T, T) => T, T, Vec[Idx, T]) => T +} + +class StaticVecR[T](r: Ring[T]) extends VecSta with VecROp[Int, T, Unit] { + import r._ + def reduce: ((T, T) => T, T, Vec[Int, T]) => T = { (plus, zero, vec) => + var sum = zero + for (i <- 0 until vec.size) + sum = plus(sum, vec(i)) + sum + } + override def toString(): String = s"StaticVecR($r)" +} + +class VecRDyn[T: Type] extends VecDyn with VecROp[Expr[Int], Expr[T], Expr[Unit]] { + def reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = { + (plus, zero, vec) => '{ + var sum = ~zero + var i = 0 + while (i < ~vec.size) { + sum = ~{ plus('(sum), vec('(i))) } + i += 1 + } + sum + } + } + override def toString(): String = s"VecRDyn" +} + +class VecRStaDim[T: Type](r: Ring[T]) extends VecROp[Int, T, Expr[Unit]] { + val M = new StaticVecR[T](r) + def reduce: ((T, T) => T, T, Vec[Int, T]) => T = M.reduce + val seq: (Expr[Unit], Expr[Unit]) => Expr[Unit] = (e1, e2) => '{ ~e1; ~e2 } + // val iter: (arr: Vec[]) = reduce seq .<()>. arr + def iter: Vec[Int, Expr[Unit]] => Expr[Unit] = arr => { + def loop(i: Int, acc: Expr[Unit]): Expr[Unit] = + if (i < arr.size) loop(i + 1, '{ ~acc; ~arr.get(i) }) + else acc + loop(0, '()) + } + override def toString(): String = s"VecRStaDim($r)" +} + +class VecRStaDyn[T : Type : Liftable](r: Ring[PV[T]]) extends VecROp[PV[Int], PV[T], Expr[Unit]] { + val VSta: VecROp[Int, PV[T], Expr[Unit]] = new VecRStaDim(r) + val VDyn = new VecRDyn + val dyn = Dyns.dyn[T] + def reduce: ((PV[T], PV[T]) => PV[T], PV[T], Vec[PV[Int], PV[T]]) => PV[T] = { (plus, zero, vec) => vec match { + case Vec(Sta(n), v) => VSta.reduce(plus, zero, Vec(n, i => v(Sta(i)))) + case Vec(Dyn(n), v) => Dyn(VDyn.reduce((x, y) => dyn(plus(Dyn(x), Dyn(y))), dyn(zero), Vec(n, i => dyn(v(Dyn(i)))))) + } + } + def iter: Vec[PV[Int], Expr[Unit]] => Expr[Unit] = arr => { + arr.size match { + case Sta(n) => + def loop(i: Int, acc: Expr[Unit]): Expr[Unit] = + if (i < n) loop(i + 1, '{ ~acc; ~arr.get(Sta(i)) }) + else acc + loop(0, '()) + case Dyn(n) => + '{ "TODO"; () } + + } + } + override def toString(): String = s"VecRStaDim($r)" +} + +object VecRStaOptDynInt { + val threshold = 3 +} + +class VecRStaOptDynInt(r: Ring[PV[Int]]) extends VecRStaDyn(r) { + val M: VecROp[PV[Int], PV[Int], Expr[Unit]] = new VecRStaDyn(r) + + override def reduce: ((PV[Int], PV[Int]) => PV[Int], PV[Int], Vec[PV[Int], PV[Int]]) => PV[Int] = (plus, zero, vec) => vec match { + case Vec(Sta(n), vecf) => + if (count_non_zeros(n, vecf) < VecRStaOptDynInt.threshold) M.reduce(plus, zero, vec) + else M.reduce(plus, zero, Vec(Dyn(n.toExpr), vecf)) + case _ => M.reduce(plus, zero, vec) + } + + private def count_non_zeros(n: Int, vecf: PV[Int] => PV[Int]): Int = { + def loop(i: Int, acc: Int): Int = { + if (i >= n) acc + else loop(i + 1, if (vecf(Sta(i)) == Sta(0)) acc else acc + 1) + } + loop(0, 0) + } +} diff --git a/tests/run-with-compiler/shonan-hmm/Vmults.scala b/tests/run-with-compiler/shonan-hmm/Vmults.scala new file mode 100644 index 000000000000..6d2755503d6e --- /dev/null +++ b/tests/run-with-compiler/shonan-hmm/Vmults.scala @@ -0,0 +1,37 @@ + +import dotty.tools.dotc.quoted.Toolbox._ +import scala.quoted._ + +class Vmult[Idx, T, Unt](tring: Ring[T], vec: VecOp[Idx, Unt]) { + private[this] val blas = new Blas1(tring, vec) + import blas._ + def vmult(vout: OVec[Idx, T, Unt], v1: Vec[Idx, T], v2: Vec[Idx, T]): Unt = vout := v1 `*.` v2 + override def toString(): String = s"Vmult($tring, $vec)" +} + +object Vmults { + def vmult(vout: Array[Complex[Int]], v1: Array[Complex[Int]], v2: Array[Complex[Int]]): Unit = { + val n = vout.length + + val vout_ = OVec(n, (i, v: Complex[Int]) => vout(i) = v) + val v1_ = Vec (n, i => v1(i)) + val v2_ = Vec (n, i => v2(i)) + + val V = new Vmult[Int, Complex[Int], Unit](RingComplex(RingInt), new VecSta) + V.vmult(vout_, v1_, v2_) + } + + def vmultCA: Expr[(Array[Complex[Int]], Array[Complex[Int]], Array[Complex[Int]]) => Unit] = '{ + (vout, v1, v2) => { + val n = vout.length + ~{ + val vout_ = OVec[Expr[Int], Complex[Expr[Int]], Expr[Unit]]('(n), (i, v) => '(vout(~i) = ~Complex.of_expr_complex(v))) + val v1_ = Vec ('(n), i => Complex.of_complex_expr('(v1(~i)))) + val v2_ = Vec ('(n), i => Complex.of_complex_expr('(v2(~i)))) + + val V = new Vmult[Expr[Int], Complex[Expr[Int]], Expr[Unit]](RingComplex(RingIntExpr), new VecDyn) + V.vmult(vout_, v1_, v2_) + } + } + } +}