Skip to content

Commit 38808d2

Browse files
committed
wip
1 parent 2211d5b commit 38808d2

File tree

5 files changed

+175
-20
lines changed

5 files changed

+175
-20
lines changed

tests/run-with-compiler/shonan-hmm.check

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,69 @@ List(25, 30, 20, 43, 44)
9494
vout.update(4, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(3)).+(v.apply(3).*(0)).+(v.apply(4).*(7)))
9595
})
9696
}
97+
98+
99+
100+
{
101+
val arr: scala.Array[scala.Array[scala.Int]] = {
102+
val array: scala.Array[scala.Array[scala.Int]] = dotty.runtime.Arrays.newGenericArray[scala.Array[scala.Int]](5)({
103+
scala.reflect.ClassTag.apply[scala.Array[scala.Int]](scala.Predef.classOf[scala.Array[scala.Int]])
104+
})
105+
array.update(0, {
106+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
107+
array.update(0, 5)
108+
array.update(1, 0)
109+
array.update(2, 0)
110+
array.update(3, 5)
111+
array.update(4, 0)
112+
array
113+
})
114+
array.update(1, {
115+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
116+
array.update(0, 0)
117+
array.update(1, 0)
118+
array.update(2, 10)
119+
array.update(3, 0)
120+
array.update(4, 0)
121+
array
122+
})
123+
array.update(2, {
124+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
125+
array.update(0, 0)
126+
array.update(1, 10)
127+
array.update(2, 0)
128+
array.update(3, 0)
129+
array.update(4, 0)
130+
array
131+
})
132+
array.update(3, {
133+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
134+
array.update(0, 0)
135+
array.update(1, 0)
136+
array.update(2, 2)
137+
array.update(3, 3)
138+
array.update(4, 5)
139+
array
140+
})
141+
array.update(4, {
142+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
143+
array.update(0, 0)
144+
array.update(1, 0)
145+
array.update(2, 3)
146+
array.update(3, 0)
147+
array.update(4, 7)
148+
array
149+
})
150+
array
151+
}
152+
153+
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
154+
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
155+
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
156+
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
157+
vout.update(1, v.apply(2).*(10))
158+
vout.update(2, v.apply(1).*(10))
159+
vout.update(3, v.apply(2).*(2).+(v.apply(3).*(3)).+(v.apply(4).*(5)))
160+
vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7)))
161+
})
162+
}

tests/run-with-compiler/shonan-hmm/MVmult.scala

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,41 +66,59 @@ object MVmult {
6666
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
6767
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
6868
}))
69-
mvmult_abs(a.length, a(0).length, a2)
69+
mvmult_abs0(new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))(a.length, a(0).length, a2)
7070
}
7171
}
7272
}
7373

74-
private def mvmult_abs(n: Int, m: Int, a: Vec[PV[Int], Vec[PV[Int], PV[Int]]]): Expr[(Array[Int], Array[Int]) => Unit] = {
74+
def mvmult_opt(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
75+
val n = a.length
76+
val m = a(0).length
77+
import Lifters._
78+
'{
79+
val arr = ~a.toExpr
80+
~{
81+
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
82+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
83+
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
84+
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
85+
}))
86+
mvmult_abs0(new RingIntOPExpr, new VecRStaDyn(new RingIntPExpr))(a.length, a(0).length, a2)
87+
}
88+
}
89+
}
90+
91+
def mvmult_roll(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
92+
val n = a.length
93+
val m = a(0).length
94+
import Lifters._
95+
'{
96+
val arr = ~a.toExpr
97+
~{
98+
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
99+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
100+
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
101+
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
102+
}))
103+
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(a.length, a(0).length, a2)
104+
}
105+
}
106+
}
107+
108+
private def mvmult_abs0(ring: Ring[PV[Int]], vecOp: VecROp[PV[Int], PV[Int], Expr[Unit]])(n: Int, m: Int, a: Vec[PV[Int], Vec[PV[Int], PV[Int]]]): Expr[(Array[Int], Array[Int]) => Unit] = {
75109
'{
76110
(vout, v) => {
77111
if (~n.toExpr != vout.length) throw new IndexOutOfBoundsException(~n.toString.toExpr)
78112
if (~m.toExpr != v.length) throw new IndexOutOfBoundsException(~m.toString.toExpr)
79113
~{
80114
val vout_ : OVec[PV[Int], PV[Int], Expr[Unit]] = OVec(Sta(n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x)))
81115
val v_ : Vec[PV[Int], PV[Int]] = Vec(Sta(m), i => Dyn('(v(~Dyns.dyni(i)))))
82-
val MV = new MVmult[PV[Int], PV[Int], Expr[Unit]](new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))
116+
val MV = new MVmult[PV[Int], PV[Int], Expr[Unit]](ring, vecOp)
83117
MV.mvmult(vout_, a, v_)
84118
}
85119
}
86120
}
87121
}
88122

89123

90-
91-
// let mvmult_abs : _ →
92-
// amat → (float array → float array → unit) code =
93-
// fun mvmult → fun {n;m;a} →
94-
// .<fun vout v →
95-
// assert (n = Array.length vout && m = Array.length v);
96-
// .~(let vout = OVec (Sta n, fun i v → .<vout.(.~(dyni i)) ← .~(dynf v)>.) in
97-
// let v = Vec (Sta m, fun j → Dyn .<v.(.~(dyni j))>.) in
98-
// mvmult vout a v)
99-
// >.
100-
// val mvmult_abs :
101-
// ((int pv, float pv, unit code) Vector.ovec →
102-
// (int pv, (int pv, float pv) Vector.vec) Vector.vec →
103-
// (int pv, float pv) Vector.vec → unit code) →
104-
// amat → (float array → float array → unit) code = <fun>
105-
106124
}

tests/run-with-compiler/shonan-hmm/Ring.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends
6868

6969
class RingIntPExpr extends RingPV(RingInt, RingIntExpr)
7070

71-
class RingIntOPCode extends RingIntPExpr {
71+
class RingIntOPExpr extends RingIntPExpr {
7272
override def add = (x: PV[Int], y: PV[Int]) => (x, y) match {
7373
case (Sta(0), y) => y
7474
case (x, Sta(0)) => x

tests/run-with-compiler/shonan-hmm/Test.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,36 @@ object Test {
6565
println()
6666

6767
println(MVmult.mvmult_ac(a).show)
68+
println()
69+
println()
70+
println()
71+
72+
println(MVmult.mvmult_opt(a).show)
73+
// FIXME
74+
// Caused by: java.lang.ArrayIndexOutOfBoundsException: 5
75+
// at MVmult$.$anonfun$52$$anonfun$7(MVmult.scala:99)
76+
// at Vec.apply(Vec.scala:4)
77+
// at Vec.zipWith$$anonfun$1(Vec.scala:9)
78+
// at VecRStaOptDynInt.loop$1(VecROp.scala:84)
79+
// at VecRStaOptDynInt.count_non_zeros(VecROp.scala:86)
80+
// at VecRStaOptDynInt.reduce$$anonfun$1(VecROp.scala:76)
81+
// at Blas2$Blas2VecOps.dot(Blas.scala:23)
82+
// at Blas2$Blas2MatOps.$times$$anonfun$1(Blas.scala:27)
83+
// at Vec.vecMap$$anonfun$1(Vec.scala:6)
84+
// at Vec.apply(Vec.scala:4)
85+
// at OVec.vecAssign$$anonfun$1(Vec.scala:14)
86+
// at VecRStaDyn.$anonfun$1$1(VecROp.scala:58)
87+
// at dotty.tools.dotc.core.tasty.TreeUnpickler$TreeReader.readHole(TreeUnpickler.scala:1179)
88+
// at dotty.tools.dotc.core.tasty.TreeUnpickler$TreeReader.readLengthTerm$1(TreeUnpickler.scala:1113)
89+
// at dotty.tools.dotc.core.tasty.TreeUnpickler$TreeReader.readTerm(TreeUnpickler.scala:1121)
90+
// at dotty.tools.dotc.core.tasty.TreeUnpickler$TreeReader.readBlock$1(TreeUnpickler.scala:1019)
91+
92+
93+
// println()
94+
// println()
95+
// println()
96+
//
97+
// println(MVmult.mvmult_roll(a).show)
6898

6999
}
70100
}

tests/run-with-compiler/shonan-hmm/VecROp.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,44 @@ class VecRStaDyn[T : Type : Liftable](r: Ring[PV[T]]) extends VecROp[PV[Int], PV
6565
}
6666
override def toString(): String = s"VecRStaDim($r)"
6767
}
68+
69+
class VecRStaOptDynInt(r: Ring[PV[Int]]) extends VecRStaDyn(r) {
70+
val M: VecROp[PV[Int], PV[Int], Expr[Unit]] = new VecRStaDyn(r)
71+
72+
private val threshold = 3
73+
74+
override def reduce: ((PV[Int], PV[Int]) => PV[Int], PV[Int], Vec[PV[Int], PV[Int]]) => PV[Int] = (plus, zero, vec) => vec match {
75+
case Vec(Sta(n), vecf) =>
76+
if (count_non_zeros(n, vecf) < threshold) M.reduce(plus, zero, vec)
77+
else M.reduce(plus, zero, Vec(Dyn(n.toExpr), vecf))
78+
case _ => M.reduce(plus, zero, vec)
79+
}
80+
81+
private def count_non_zeros(n: Int, vecf: PV[Int] => PV[Int]): Int = {
82+
def loop(i: Int, acc: Int): Int = {
83+
if (i >= n) acc
84+
else loop(i + 1, if (vecf(Sta(n)) == Sta(0)) acc else acc + 1)
85+
}
86+
loop(0, 0)
87+
}
88+
}
89+
90+
//module VecRStaOptDynFloat = struct
91+
//module R = RingFloatOPCode
92+
//module M = VecRStaDyn(Lift_float)
93+
//include M
94+
// let threshold = 3
95+
//let count_non_zeros n vecf =
96+
// let rec loop acc i =
97+
// if i ≥ n then acc else
98+
// let acc = if vecf (Sta i) = Sta 0. then acc else acc + 1 in
99+
// loop acc (i+1)
100+
//in loop 0 0
101+
//let reduce plus zero = function
102+
//| (Vec (Sta n,vecf)) as vec →
103+
//if count_non_zeros n vecf < threshold then
104+
// M.reduce plus zero vec
105+
//else
106+
// M.reduce plus zero (Vec (Dyn .<n>.,vecf))
107+
//| vec → M.reduce plus zero vec
108+
// end

0 commit comments

Comments
 (0)