@@ -31,7 +31,7 @@ object MVmult {
31
31
val a_ = Vec ('(n), (i: Expr[Int]) => Vec( ' (m), (j : Expr [Int ]) => ' { a(~ i)(~ j) } ))
32
32
val v_ = Vec ('(m), (i: Expr[Int]) => ' (v(~ i)))
33
33
34
- val MV = new MVmult [Expr [Int ], Expr [Int ], Expr [Unit ]](RingIntExpr , new VecRDyn ( RingIntExpr ) )
34
+ val MV = new MVmult [Expr [Int ], Expr [Int ], Expr [Unit ]](RingIntExpr , new VecRDyn )
35
35
MV .mvmult(vout_, a_, v_)
36
36
}
37
37
}
@@ -41,7 +41,8 @@ object MVmult {
41
41
val MV = new MVmult [Int , Expr [Int ], Expr [Unit ]](RingIntExpr , new VecRStaDim (RingIntExpr ))
42
42
' {
43
43
(vout, a, v) => {
44
- assert (~ n.toExpr == vout.length && ~ m.toExpr == v.length)
44
+ if (~ n.toExpr != vout.length) throw new IndexOutOfBoundsException (~ n.toString.toExpr)
45
+ if (~ m.toExpr != v.length) throw new IndexOutOfBoundsException (~ m.toString.toExpr)
45
46
~ {
46
47
val vout_ = OVec (n, (i, x : Expr [Int ]) => '(vout(~i.toExpr) = ~x))
47
48
val a_ = Vec (n, i => Vec (m, j => ' { a(~ i.toExpr)(~ j.toExpr) } ))
@@ -56,32 +57,50 @@ object MVmult {
56
57
def mvmult_ac (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
57
58
val n = a.length
58
59
val m = a(0 ).length
60
+ import Lifters ._
61
+ ' {
62
+ val arr = ~ a.toExpr
63
+ ~ {
64
+ val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
65
+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
66
+ case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
67
+ case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
68
+ }))
69
+ mvmult_abs(a.length, a(0 ).length, a2)
70
+ }
71
+ }
72
+ }
59
73
60
- // Array lifters
61
-
62
-
74
+ private def mvmult_abs (n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
63
75
' {
64
- val arr = Array ( // FIXMR lift a
65
- Array ( 5 , 0 , 0 , 5 , 0 ),
66
- Array ( 0 , 0 , 10 , 0 , 0 ),
67
- Array ( 0 , 10 , 0 , 0 , 0 ),
68
- Array ( 0 , 0 , 2 , 3 , 5 ),
69
- Array ( 0 , 0 , 3 , 0 , 7 )
70
- )
71
76
(vout, v) => {
72
- assert (~ n.toExpr == vout.length && ~ m.toExpr == v.length)
77
+ if (~ n.toExpr != vout.length) throw new IndexOutOfBoundsException (~ n.toString.toExpr)
78
+ if (~ m.toExpr != v.length) throw new IndexOutOfBoundsException (~ m.toString.toExpr)
73
79
~ {
74
- val vout_ : OVec [PV [Int ], Expr [Int ], Expr [Unit ]] = OVec (Sta (n), (i, x) => '(vout(~Dyns.dyni(i)) = ~x))
75
- val a2 : Vec [PV [Int ], Vec [PV [Int ], Expr [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => Dyns .dyn((i, j) match {
76
- case (Sta (i), Sta (j)) => Sta (a(i)(j))
77
- case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
78
- case (i, j) => Dyn (' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
79
- })))
80
- val v_ : Vec [PV [Int ], Expr [Int ]] = Vec (Sta (m), i => '(v(~Dyns.dyni(i))))
81
- val MV = new MVmult [PV [Int ], Expr [Int ], Expr [Unit ]](RingIntExpr , new VecRStaDyn (RingIntExpr ))
82
- MV .mvmult(vout_, a2, v_)
80
+ val vout_ : OVec [PV [Int ], PV [Int ], Expr [Unit ]] = OVec (Sta (n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x)))
81
+ val v_ : Vec [PV [Int ], PV [Int ]] = Vec (Sta (m), i => Dyn ('(v(~Dyns.dyni(i)))))
82
+ val MV = new MVmult [PV [Int ], PV [Int ], Expr [Unit ]](new RingIntPExpr , new VecRStaDyn (new RingIntPExpr ))
83
+ MV .mvmult(vout_, a, v_)
83
84
}
84
85
}
85
86
}
86
87
}
88
+
89
+
90
+
91
+ // let mvmult_abs : _ →
92
+ // amat → (float array → float array → unit) code =
93
+ // fun mvmult → fun {n;m;a} →
94
+ // .<fun vout v →
95
+ // assert (n = Array.length vout && m = Array.length v);
96
+ // .~(let vout = OVec (Sta n, fun i v → .<vout.(.~(dyni i)) ← .~(dynf v)>.) in
97
+ // let v = Vec (Sta m, fun j → Dyn .<v.(.~(dyni j))>.) in
98
+ // mvmult vout a v)
99
+ // >.
100
+ // val mvmult_abs :
101
+ // ((int pv, float pv, unit code) Vector.ovec →
102
+ // (int pv, (int pv, float pv) Vector.vec) Vector.vec →
103
+ // (int pv, float pv) Vector.vec → unit code) →
104
+ // amat → (float array → float array → unit) code = <fun>
105
+
87
106
}
0 commit comments