@@ -66,41 +66,59 @@ object MVmult {
66
66
case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
67
67
case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
68
68
}))
69
- mvmult_abs (a.length, a(0 ).length, a2)
69
+ mvmult_abs0( new RingIntPExpr , new VecRStaDyn ( new RingIntPExpr )) (a.length, a(0 ).length, a2)
70
70
}
71
71
}
72
72
}
73
73
74
- private def mvmult_abs (n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
74
+ def mvmult_opt (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
75
+ val n = a.length
76
+ val m = a(0 ).length
77
+ import Lifters ._
78
+ ' {
79
+ val arr = ~ a.toExpr
80
+ ~ {
81
+ val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
82
+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
83
+ case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
84
+ case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
85
+ }))
86
+ mvmult_abs0(new RingIntOPExpr , new VecRStaDyn (new RingIntPExpr ))(a.length, a(0 ).length, a2)
87
+ }
88
+ }
89
+ }
90
+
91
+ def mvmult_roll (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
92
+ val n = a.length
93
+ val m = a(0 ).length
94
+ import Lifters ._
95
+ ' {
96
+ val arr = ~ a.toExpr
97
+ ~ {
98
+ val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
99
+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
100
+ case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
101
+ case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
102
+ }))
103
+ mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(a.length, a(0 ).length, a2)
104
+ }
105
+ }
106
+ }
107
+
108
+ private def mvmult_abs0 (ring : Ring [PV [Int ]], vecOp : VecROp [PV [Int ], PV [Int ], Expr [Unit ]])(n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
75
109
' {
76
110
(vout, v) => {
77
111
if (~ n.toExpr != vout.length) throw new IndexOutOfBoundsException (~ n.toString.toExpr)
78
112
if (~ m.toExpr != v.length) throw new IndexOutOfBoundsException (~ m.toString.toExpr)
79
113
~ {
80
114
val vout_ : OVec [PV [Int ], PV [Int ], Expr [Unit ]] = OVec (Sta (n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x)))
81
115
val v_ : Vec [PV [Int ], PV [Int ]] = Vec (Sta (m), i => Dyn ('(v(~Dyns.dyni(i)))))
82
- val MV = new MVmult [PV [Int ], PV [Int ], Expr [Unit ]](new RingIntPExpr , new VecRStaDyn ( new RingIntPExpr ) )
116
+ val MV = new MVmult [PV [Int ], PV [Int ], Expr [Unit ]](ring, vecOp )
83
117
MV .mvmult(vout_, a, v_)
84
118
}
85
119
}
86
120
}
87
121
}
88
122
89
123
90
-
91
- // let mvmult_abs : _ →
92
- // amat → (float array → float array → unit) code =
93
- // fun mvmult → fun {n;m;a} →
94
- // .<fun vout v →
95
- // assert (n = Array.length vout && m = Array.length v);
96
- // .~(let vout = OVec (Sta n, fun i v → .<vout.(.~(dyni i)) ← .~(dynf v)>.) in
97
- // let v = Vec (Sta m, fun j → Dyn .<v.(.~(dyni j))>.) in
98
- // mvmult vout a v)
99
- // >.
100
- // val mvmult_abs :
101
- // ((int pv, float pv, unit code) Vector.ovec →
102
- // (int pv, (int pv, float pv) Vector.vec) Vector.vec →
103
- // (int pv, float pv) Vector.vec → unit code) →
104
- // amat → (float array → float array → unit) code = <fun>
105
-
106
124
}
0 commit comments