Skip to content

Commit 8642863

Browse files
Merge pull request #4726 from dotty-staging/shonan-hmm
Implement Shonan HMM
2 parents d1a184c + 51e9276 commit 8642863

16 files changed

+1219
-0
lines changed

library/src/scala/tasty/util/ShowSourceCode.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
369369
expr match {
370370
case Term.Lambda(_, _) =>
371371
// Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)}
372+
assert(stats.size == 1)
372373
val DefDef(_, _, args :: Nil, _, Some(rhs)) :: Nil = stats
373374
inParens {
374375
printArgsDefs(args)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
10
2+
3+
Complex(4,3)
4+
5+
0.+(0.*(1)).+(1.*(0)).+(2.*(1)).+(4.*(0)).+(8.*(1))
6+
10
7+
8+
((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => {
9+
if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else ()
10+
var sum: scala.Int = 0
11+
var i: scala.Int = 0
12+
while (i.<(scala.Predef.intArrayOps(arr1).size)) {
13+
sum = sum.+(arr1.apply(i).*(arr2.apply(i)))
14+
i = i.+(1)
15+
}
16+
(sum: scala.Int)
17+
})
18+
10
19+
20+
0.+(2).+(8)
21+
10
22+
23+
((arr: scala.Array[scala.Int]) => {
24+
if (arr.length.!=(5)) throw new scala.Exception("...") else ()
25+
arr.apply(0).+(arr.apply(2)).+(arr.apply(4))
26+
})
27+
10
28+
29+
((arr: scala.Array[Complex[scala.Int]]) => {
30+
if (arr.length.!=(4)) throw new scala.Exception("...") else ()
31+
Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2)))
32+
})
33+
Complex(4,3)
34+
35+
((arr: scala.Array[scala.Int]) => arr.apply(1).+(arr.apply(3)))
36+
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import scala.quoted._
2+
3+
trait Ring[T] {
4+
val zero: T
5+
val one: T
6+
val add: (x: T, y: T) => T
7+
val sub: (x: T, y: T) => T
8+
val mul: (x: T, y: T) => T
9+
}
10+
11+
class RingInt extends Ring[Int] {
12+
val zero = 0
13+
val one = 1
14+
val add = (x, y) => x + y
15+
val sub = (x, y) => x - y
16+
val mul = (x, y) => x * y
17+
}
18+
19+
class RingIntExpr extends Ring[Expr[Int]] {
20+
val zero = '(0)
21+
val one = '(1)
22+
val add = (x, y) => '(~x + ~y)
23+
val sub = (x, y) => '(~x - ~y)
24+
val mul = (x, y) => '(~x * ~y)
25+
}
26+
27+
class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] {
28+
val zero = Complex(u.zero, u.zero)
29+
val one = Complex(u.one, u.zero)
30+
val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im))
31+
val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im))
32+
val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
33+
}
34+
35+
sealed trait PV[T] {
36+
def expr(implicit l: Liftable[T]): Expr[T]
37+
}
38+
case class Sta[T](x: T) extends PV[T] {
39+
def expr(implicit l: Liftable[T]): Expr[T] = x.toExpr
40+
}
41+
case class Dyn[T](x: Expr[T]) extends PV[T] {
42+
def expr(implicit l: Liftable[T]): Expr[T] = x
43+
}
44+
45+
class RingPV[U: Liftable](u: Ring[U], eu: Ring[Expr[U]]) extends Ring[PV[U]] {
46+
val zero: PV[U] = Sta(u.zero)
47+
val one: PV[U] = Sta(u.one)
48+
val add = (x: PV[U], y: PV[U]) => (x, y) match {
49+
case (Sta(u.zero), x) => x
50+
case (x, Sta(u.zero)) => x
51+
case (Sta(x), Sta(y)) => Sta(u.add(x, y))
52+
case (x, y) => Dyn(eu.add(x.expr, y.expr))
53+
}
54+
val sub = (x: PV[U], y: PV[U]) => (x, y) match {
55+
case (x, Sta(u.zero)) => x
56+
case (Sta(x), Sta(y)) => Sta(u.sub(x, y))
57+
case (x, y) => Dyn(eu.sub(x.expr, y.expr))
58+
}
59+
val mul = (x: PV[U], y: PV[U]) => (x, y) match {
60+
case (Sta(u.zero), _) => Sta(u.zero)
61+
case (_, Sta(u.zero)) => Sta(u.zero)
62+
case (Sta(u.one), x) => x
63+
case (x, Sta(u.one)) => x
64+
case (Sta(x), Sta(y)) => Sta(u.mul(x, y))
65+
case (x, y) => Dyn(eu.mul(x.expr, y.expr))
66+
}
67+
}
68+
69+
70+
case class Complex[T](re: T, im: T)
71+
72+
object Complex {
73+
implicit def isLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable[Complex[T]] {
74+
def toExpr(comp: Complex[T]): Expr[Complex[T]] = '(Complex(~comp.re.toExpr, ~comp.im.toExpr))
75+
}
76+
}
77+
78+
case class Vec[Idx, T](size: Idx, get: Idx => T) {
79+
def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i)))
80+
def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i)))
81+
}
82+
83+
object Vec {
84+
def from[T](elems: T*): Vec[Int, T] = new Vec(elems.size, i => elems(i))
85+
}
86+
87+
trait VecOps[Idx, T] {
88+
val reduce: ((T, T) => T, T, Vec[Idx, T]) => T
89+
}
90+
91+
class StaticVecOps[T] extends VecOps[Int, T] {
92+
val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => {
93+
var sum = zero
94+
for (i <- 0 until vec.size)
95+
sum = plus(sum, vec.get(i))
96+
sum
97+
}
98+
}
99+
100+
class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] {
101+
val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{
102+
var sum = ~zero
103+
var i = 0
104+
while (i < ~vec.size) {
105+
sum = ~{ plus('(sum), vec.get('(i))) }
106+
i += 1
107+
}
108+
sum
109+
}
110+
}
111+
112+
class Blas1[Idx, T](r: Ring[T], ops: VecOps[Idx, T]) {
113+
def dot(v1: Vec[Idx, T], v2: Vec[Idx, T]): T = ops.reduce(r.add, r.zero, v1.zipWith(v2, r.mul))
114+
}
115+
116+
object Test {
117+
118+
implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make
119+
120+
def main(args: Array[String]): Unit = {
121+
val arr1 = Array(0, 1, 2, 4, 8)
122+
val arr2 = Array(1, 0, 1, 0, 1)
123+
val cmpxArr1 = Array(Complex(1, 0), Complex(2, 3), Complex(0, 2), Complex(3, 1))
124+
val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 1), Complex(2, 0))
125+
126+
val vec1 = new Vec(arr1.size, i => arr1(i))
127+
val vec2 = new Vec(arr2.size, i => arr2(i))
128+
val cmpxVec1 = new Vec(cmpxArr1.size, i => cmpxArr1(i))
129+
val cmpxVec2 = new Vec(cmpxArr2.size, i => cmpxArr2(i))
130+
131+
val blasInt = new Blas1(new RingInt, new StaticVecOps)
132+
val res1 = blasInt.dot(vec1, vec2)
133+
println(res1)
134+
println()
135+
136+
val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps)
137+
val res2 = blasComplexInt.dot(
138+
cmpxVec1,
139+
cmpxVec2
140+
)
141+
println(res2)
142+
println()
143+
144+
val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps)
145+
val resCode1 = blasStaticIntExpr.dot(
146+
vec1.map(_.toExpr),
147+
vec2.map(_.toExpr)
148+
)
149+
println(resCode1.show)
150+
println(resCode1.run)
151+
println()
152+
153+
val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps)
154+
val resCode2: Expr[(Array[Int], Array[Int]) => Int] = '{
155+
(arr1, arr2) =>
156+
if (arr1.length != arr2.length) throw new Exception("...")
157+
~{
158+
blasExprIntExpr.dot(
159+
new Vec('(arr1.size), i => '(arr1(~i))),
160+
new Vec('(arr2.size), i => '(arr2(~i)))
161+
)
162+
}
163+
}
164+
println(resCode2.show)
165+
println(resCode2.run.apply(arr1, arr2))
166+
println()
167+
168+
val blasStaticIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
169+
val resCode3 = blasStaticIntPVExpr.dot(
170+
vec1.map(i => Dyn(i.toExpr)),
171+
vec2.map(i => Sta(i))
172+
).expr
173+
println(resCode3.show)
174+
println(resCode3.run)
175+
println()
176+
177+
val blasExprIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
178+
val resCode4: Expr[Array[Int] => Int] = '{
179+
arr =>
180+
if (arr.length != ~vec2.size.toExpr) throw new Exception("...")
181+
~{
182+
blasExprIntPVExpr.dot(
183+
new Vec(vec2.size, i => Dyn('(arr(~i.toExpr)))),
184+
vec2.map(i => Sta(i))
185+
).expr
186+
}
187+
188+
}
189+
println(resCode4.show)
190+
println(resCode4.run.apply(arr1))
191+
println()
192+
193+
import Complex.isLiftable
194+
val blasExprComplexPVInt = new Blas1[Int, Complex[PV[Int]]](new RingComplex(new RingPV[Int](new RingInt, new RingIntExpr)), new StaticVecOps)
195+
val resCode5: Expr[Array[Complex[Int]] => Complex[Int]] = '{
196+
arr =>
197+
if (arr.length != ~cmpxVec2.size.toExpr) throw new Exception("...")
198+
~{
199+
val cpx = blasExprComplexPVInt.dot(
200+
new Vec(cmpxVec2.size, i => Complex(Dyn('(arr(~i.toExpr).re)), Dyn('(arr(~i.toExpr).im)))),
201+
new Vec(cmpxVec2.size, i => Complex(Sta(cmpxVec2.get(i).re), Sta(cmpxVec2.get(i).im)))
202+
)
203+
'(Complex(~cpx.re.expr, ~cpx.im.expr))
204+
}
205+
}
206+
println(resCode5.show)
207+
println(resCode5.run.apply(cmpxArr1))
208+
println()
209+
210+
val RingPVInt = new RingPV[Int](new RingInt, new RingIntExpr)
211+
// Staged loop of dot product on vectors of Int or Expr[Int]
212+
val dotIntOptExpr = new Blas1(RingPVInt, new StaticVecOps).dot
213+
// will generate the code '{ ((arr: scala.Array[scala.Int]) => arr.apply(1).+(arr.apply(3))) }
214+
val staticVec = Vec[Int, PV[Int]](5, i => Sta((i % 2)))
215+
val code = '{(arr: Array[Int]) => ~dotIntOptExpr(Vec(5, i => Dyn('(arr(~i.toExpr)))), staticVec).expr }
216+
println(code.show)
217+
println()
218+
}
219+
220+
}

0 commit comments

Comments
 (0)