Skip to content

Commit a914b02

Browse files
syurkevi authored and 9prady9 committed
adds gemm functionality, complex ctypes
1 parent e053bb5 commit a914b02

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

arrayfire/blas.py

+106
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,109 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar =
202202
safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
203203
lhs_opts.value, rhs_opts.value))
204204
return out
205+
206+
def _as_gemm_scalar(value, scalar_t, is_complex):
    """Box a gemm() scalar (alpha or beta) into a ctypes object of type `scalar_t`."""
    if not is_complex:
        return scalar_t(value)
    if isinstance(value, scalar_t):
        # Already the right ctypes struct: pass the caller's object through.
        return value
    if isinstance(value, tuple):
        return scalar_t(value[0], value[1])
    if isinstance(value, complex):
        # A native Python complex cannot be converted to a C float field
        # directly, so split it into its components explicitly.
        return scalar_t(value.real, value.imag)
    # Plain real number: the imaginary field keeps its ctypes default of 0.
    return scalar_t(value)


def gemm(lhs, rhs, alpha=1.0, beta=0.0, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, C=None):
    """
    BLAS general matrix multiply (GEMM) of two af_array objects.

    This provides a general interface to the BLAS level 3 general matrix multiply (GEMM),
    which is generally defined as:

        C = α ∗ opA(A) opB(B) + β ∗ C

    where α (alpha) and β (beta) are both scalars; A and B are the matrix multiply operands;
    and opA and opB are noop (if AF_MAT_NONE) or transpose (if AF_MAT_TRANS) operations
    on A or B before the actual GEMM operation.
    Batched GEMM is supported if at least either A or B have more than two dimensions
    (see af::matmul for more details on broadcasting).
    However, only one alpha and one beta can be used for all of the batched matrix operands.

    Parameters
    ----------

    lhs : af.Array
        A real or complex arrayfire array (the A operand). Its dtype selects how
        `alpha` and `beta` are converted; supported dtypes are f32, f64, c32 and c64.

    rhs : af.Array
        A real or complex arrayfire array (the B operand).

    alpha : scalar. default: 1.0.
        Multiplier of the matrix product. For real `lhs` it must be convertible to a
        C float/double. For complex `lhs` it may be a real number, a Python `complex`,
        a `(real, imag)` tuple, or an `af_cfloat_t` / `af_cdouble_t` instance matching
        the dtype of `lhs`.

    beta : scalar. default: 0.0.
        Multiplier of `C`. Accepts the same kinds of values as `alpha`.

    lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
        Can be one of
        - af.MATPROP.NONE   - If no op should be done on `lhs`.
        - af.MATPROP.TRANS  - If `lhs` has to be transposed before multiplying.
        - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.

    rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
        Can be one of
        - af.MATPROP.NONE   - If no op should be done on `rhs`.
        - af.MATPROP.TRANS  - If `rhs` has to be transposed before multiplying.
        - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.

    C : optional: af.Array. default: None.
        If given, this array is handed to the backend as the output array and is the
        object returned. If None, a new af.Array is created for the result.

    Returns
    -------

    out : af.Array
        Output of the matrix multiplication on `lhs` and `rhs`
        (the same object as `C` when `C` is given).

    Raises
    ------

    TypeError
        If the dtype of `lhs` is f16 or any other type that is not f32, f64, c32 or c64.

    Note
    -----

    - The data types of `lhs` and `rhs` should be the same.

    """
    out = Array() if C is None else C

    ltype = lhs.dtype()

    if ltype == Dtype.f32:
        scalar_t, is_complex = c_float_t, False
    elif ltype == Dtype.c32:
        scalar_t, is_complex = af_cfloat_t, True
    elif ltype == Dtype.f64:
        scalar_t, is_complex = c_double_t, False
    elif ltype == Dtype.c64:
        scalar_t, is_complex = af_cdouble_t, True
    elif ltype == Dtype.f16:
        raise TypeError("fp16 currently unsupported gemm() input type")
    else:
        raise TypeError("unsupported input type")

    # Keep the boxed scalars bound to locals so the memory behind the void
    # pointers stays referenced for the whole duration of the native call.
    alpha_c = _as_gemm_scalar(alpha, scalar_t, is_complex)
    beta_c = _as_gemm_scalar(beta, scalar_t, is_complex)
    aptr = c_cast(c_pointer(alpha_c), c_void_ptr_t)
    bptr = c_cast(c_pointer(beta_c), c_void_ptr_t)

    safe_call(backend.get().af_gemm(c_pointer(out.arr),
                                    lhs_opts.value, rhs_opts.value,
                                    aptr, lhs.arr, rhs.arr, bptr))
    return out

arrayfire/library.py

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
c_void_ptr_t = ct.c_void_p
3232
c_char_ptr_t = ct.c_char_p
3333
c_size_t = ct.c_size_t
34+
c_cast = ct.cast
35+
36+
class af_cfloat_t(ct.Structure):
    """ctypes structure holding a complex value as `real` and `imag` C floats."""

    _fields_ = [
        ("real", ct.c_float),
        ("imag", ct.c_float),
    ]
38+
39+
class af_cdouble_t(ct.Structure):
    """ctypes structure holding a complex value as `real` and `imag` C doubles."""

    _fields_ = [
        ("real", ct.c_double),
        ("imag", ct.c_double),
    ]
3441

3542

3643
AF_VER_MAJOR = '3'

0 commit comments

Comments
 (0)