@@ -202,3 +202,109 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar =
202
202
safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
203
203
lhs_opts .value , rhs_opts .value ))
204
204
return out
205
+
206
+ def gemm (lhs , rhs , alpha = 1.0 , beta = 0.0 , lhs_opts = MATPROP .NONE , rhs_opts = MATPROP .NONE , C = None ):
207
+ """
208
+ BLAS general matrix multiply (GEMM) of two af_array objects.
209
+
210
+ This provides a general interface to the BLAS level 3 general matrix multiply (GEMM), which is generally defined as:
211
+
212
+ C = α ∗ opA(A) opB(B)+ β∗C
213
+
214
+ where α (alpha) and β (beta) are both scalars; A and B are the matrix multiply operands;
215
+ and opA and opB are noop (if AF_MAT_NONE) or transpose (if AF_MAT_TRANS) operations
216
+ on A or B before the actual GEMM operation.
217
+ Batched GEMM is supported if at least either A or B have more than two dimensions
218
+ (see af::matmul for more details on broadcasting).
219
+ However, only one alpha and one beta can be used for all of the batched matrix operands.
220
+
221
+ Parameters
222
+ ----------
223
+
224
+ lhs : af.Array
225
+ A 2 dimensional, real or complex arrayfire array.
226
+
227
+ rhs : af.Array
228
+ A 2 dimensional, real or complex arrayfire array.
229
+
230
+ alpha : scalar
231
+
232
+ beta : scalar
233
+
234
+ lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
235
+ Can be one of
236
+ - af.MATPROP.NONE - If no op should be done on `lhs`.
237
+ - af.MATPROP.TRANS - If `lhs` has to be transposed before multiplying.
238
+ - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.
239
+
240
+ rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
241
+ Can be one of
242
+ - af.MATPROP.NONE - If no op should be done on `rhs`.
243
+ - af.MATPROP.TRANS - If `rhs` has to be transposed before multiplying.
244
+ - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.
245
+
246
+ Returns
247
+ -------
248
+
249
+ out : af.Array
250
+ Output of the matrix multiplication on `lhs` and `rhs`.
251
+
252
+ Note
253
+ -----
254
+
255
+ - The data types of `lhs` and `rhs` should be the same.
256
+ - Batches are not supported.
257
+
258
+ """
259
+ if C is None :
260
+ out = Array ()
261
+ else :
262
+ out = C
263
+
264
+ ltype = lhs .dtype ()
265
+
266
+ if ltype == Dtype .f32 :
267
+ aptr = c_cast (c_pointer (c_float_t (alpha )),c_void_ptr_t )
268
+ bptr = c_cast (c_pointer (c_float_t (beta )), c_void_ptr_t )
269
+ elif ltype == Dtype .c32 :
270
+ if isinstance (alpha , af_cfloat_t ):
271
+ aptr = c_cast (c_pointer (alpha ), c_void_ptr_t )
272
+ elif isinstance (alpha , tuple ):
273
+ aptr = c_cast (c_pointer (af_cfloat_t (alpha [0 ], alpha [1 ])), c_void_ptr_t )
274
+ else :
275
+ aptr = c_cast (c_pointer (af_cfloat_t (alpha )), c_void_ptr_t )
276
+
277
+ if isinstance (beta , af_cfloat_t ):
278
+ bptr = c_cast (c_pointer (beta ), c_void_ptr_t )
279
+ elif isinstance (beta , tuple ):
280
+ bptr = c_cast (c_pointer (af_cfloat_t (beta [0 ], beta [1 ])), c_void_ptr_t )
281
+ else :
282
+ bptr = c_cast (c_pointer (af_cfloat_t (beta )), c_void_ptr_t )
283
+
284
+ elif ltype == Dtype .f64 :
285
+ aptr = c_cast (c_pointer (c_double_t (alpha )),c_void_ptr_t )
286
+ bptr = c_cast (c_pointer (c_double_t (beta )), c_void_ptr_t )
287
+ elif ltype == Dtype .c64 :
288
+ if isinstance (alpha , af_cdouble_t ):
289
+ aptr = c_cast (c_pointer (alpha ), c_void_ptr_t )
290
+ elif isinstance (alpha , tuple ):
291
+ aptr = c_cast (c_pointer (af_cdouble_t (alpha [0 ], alpha [1 ])), c_void_ptr_t )
292
+ else :
293
+ aptr = c_cast (c_pointer (af_cdouble_t (alpha )), c_void_ptr_t )
294
+
295
+ if isinstance (beta , af_cdouble_t ):
296
+ bptr = c_cast (c_pointer (beta ), c_void_ptr_t )
297
+ elif isinstance (beta , tuple ):
298
+ bptr = c_cast (c_pointer (af_cdouble_t (beta [0 ], beta [1 ])), c_void_ptr_t )
299
+ else :
300
+ bptr = c_cast (c_pointer (af_cdouble_t (beta )), c_void_ptr_t )
301
+ elif ltype == Dtype .f16 :
302
+ raise TypeError ("fp16 currently unsupported gemm() input type" )
303
+ else :
304
+ raise TypeError ("unsupported input type" )
305
+
306
+
307
+ safe_call (backend .get ().af_gemm (c_pointer (out .arr ),
308
+ lhs_opts .value , rhs_opts .value ,
309
+ aptr , lhs .arr , rhs .arr , bptr ))
310
+ return out
0 commit comments