Skip to content

Commit ec86afa

Browse files
authored
Add 3.7 features to python wrapper (#221)
* adds af_pad to python wrapper * adds meanvar to python wrapper * adds inverse square root to python wrapper * adds pinverse to python wrapper * adds NN convolve and gradient functions to wrapper * adds reduce by key to python wrapper missing convolve gradient function * adds confidenceCC to python wrapper * adds fp16 support to python wrapper * update version * remove stray print statements * adds axes_label_format to python wrapper, removes mistakenly copied code
1 parent aead039 commit ec86afa

23 files changed

+636
-7
lines changed

__af_version__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
# http://arrayfire.com/licenses/BSD-3-Clause
1010
########################################################
1111

12-
version = "3.5"
13-
release = "20170718"
12+
version = "3.7"
13+
release = "20200213"
1414
full_version = version + "." + release

arrayfire/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from .timer import *
7575
from .random import *
7676
from .sparse import *
77+
from .ml import *
7778

7879
# do not export default modules as part of arrayfire
7980
del ct

arrayfire/algorithm.py

+190
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,31 @@ def _nan_reduce_all(a, c_func, nan_val):
4444
imag = imag.value
4545
return real if imag == 0 else real + imag * 1j
4646

47+
def _FNSD(dim, dims):
48+
if dim >= 0:
49+
return int(dim)
50+
51+
fnsd = 0
52+
for i, d in enumerate(dims):
53+
if d > 1:
54+
fnsd = i
55+
break
56+
return int(fnsd)
57+
58+
def _rbk_dim(keys, vals, dim, c_func):
59+
keys_out = Array()
60+
vals_out = Array()
61+
rdim = _FNSD(dim, vals.dims())
62+
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
63+
return keys_out, vals_out
64+
65+
def _nan_rbk_dim(a, dim, c_func, nan_val):
66+
keys_out = Array()
67+
vals_out = Array()
68+
rdim = _FNSD(dim, vals.dims())
69+
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
70+
return keys_out, vals_out
71+
4772
def sum(a, dim=None, nan_val=None):
4873
"""
4974
Calculate the sum of all the elements along a specified dimension.
@@ -74,6 +99,34 @@ def sum(a, dim=None, nan_val=None):
7499
else:
75100
return _reduce_all(a, backend.get().af_sum_all)
76101

102+
103+
def sumByKey(keys, vals, dim=-1, nan_val=None):
104+
"""
105+
Calculate the sum of elements along a specified dimension according to a key.
106+
107+
Parameters
108+
----------
109+
keys : af.Array
110+
One dimensional arrayfire array with reduction keys.
111+
vals : af.Array
112+
Multi dimensional arrayfire array that will be reduced.
113+
dim: optional: int. default: -1
114+
Dimension along which the sum will occur.
115+
nan_val: optional: scalar. default: None
116+
The value that replaces NaN in the array
117+
118+
Returns
119+
-------
120+
keys: af.Array or scalar number
121+
The reduced keys of all elements in `vals` along dimension `dim`.
122+
values: af.Array or scalar number
123+
The sum of all elements in `vals` along dimension `dim` according to keys
124+
"""
125+
if (nan_val is not None):
126+
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
127+
else:
128+
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
129+
77130
def product(a, dim=None, nan_val=None):
78131
"""
79132
Calculate the product of all the elements along a specified dimension.
@@ -104,6 +157,33 @@ def product(a, dim=None, nan_val=None):
104157
else:
105158
return _reduce_all(a, backend.get().af_product_all)
106159

160+
def productByKey(keys, vals, dim=-1, nan_val=None):
161+
"""
162+
Calculate the product of elements along a specified dimension according to a key.
163+
164+
Parameters
165+
----------
166+
keys : af.Array
167+
One dimensional arrayfire array with reduction keys.
168+
vals : af.Array
169+
Multi dimensional arrayfire array that will be reduced.
170+
dim: optional: int. default: -1
171+
Dimension along which the product will occur.
172+
nan_val: optional: scalar. default: None
173+
The value that replaces NaN in the array
174+
175+
Returns
176+
-------
177+
keys: af.Array or scalar number
178+
The reduced keys of all elements in `vals` along dimension `dim`.
179+
values: af.Array or scalar number
180+
The product of all elements in `vals` along dimension `dim` according to keys
181+
"""
182+
if (nan_val is not None):
183+
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
184+
else:
185+
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
186+
107187
def min(a, dim=None):
108188
"""
109189
Find the minimum value of all the elements along a specified dimension.
@@ -126,6 +206,28 @@ def min(a, dim=None):
126206
else:
127207
return _reduce_all(a, backend.get().af_min_all)
128208

209+
def minByKey(keys, vals, dim=-1):
210+
"""
211+
Calculate the min of elements along a specified dimension according to a key.
212+
213+
Parameters
214+
----------
215+
keys : af.Array
216+
One dimensional arrayfire array with reduction keys.
217+
vals : af.Array
218+
Multi dimensional arrayfire array that will be reduced.
219+
dim: optional: int. default: -1
220+
Dimension along which the min will occur.
221+
222+
Returns
223+
-------
224+
keys: af.Array or scalar number
225+
The reduced keys of all elements in `vals` along dimension `dim`.
226+
values: af.Array or scalar number
227+
The min of all elements in `vals` along dimension `dim` according to keys
228+
"""
229+
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)
230+
129231
def max(a, dim=None):
130232
"""
131233
Find the maximum value of all the elements along a specified dimension.
@@ -148,6 +250,28 @@ def max(a, dim=None):
148250
else:
149251
return _reduce_all(a, backend.get().af_max_all)
150252

253+
def maxByKey(keys, vals, dim=-1):
254+
"""
255+
Calculate the max of elements along a specified dimension according to a key.
256+
257+
Parameters
258+
----------
259+
keys : af.Array
260+
One dimensional arrayfire array with reduction keys.
261+
vals : af.Array
262+
Multi dimensional arrayfire array that will be reduced.
263+
dim: optional: int. default: -1
264+
Dimension along which the max will occur.
265+
266+
Returns
267+
-------
268+
keys: af.Array or scalar number
269+
The reduced keys of all elements in `vals` along dimension `dim`.
270+
values: af.Array or scalar number
271+
The max of all elements in `vals` along dimension `dim` according to keys.
272+
"""
273+
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)
274+
151275
def all_true(a, dim=None):
152276
"""
153277
Check if all the elements along a specified dimension are true.
@@ -170,6 +294,28 @@ def all_true(a, dim=None):
170294
else:
171295
return _reduce_all(a, backend.get().af_all_true_all)
172296

297+
def allTrueByKey(keys, vals, dim=-1):
298+
"""
299+
Calculate if all elements are true along a specified dimension according to a key.
300+
301+
Parameters
302+
----------
303+
keys : af.Array
304+
One dimensional arrayfire array with reduction keys.
305+
vals : af.Array
306+
Multi dimensional arrayfire array that will be reduced.
307+
dim: optional: int. default: -1
308+
Dimension along which the all true check will occur.
309+
310+
Returns
311+
-------
312+
keys: af.Array or scalar number
313+
The reduced keys of all true check in `vals` along dimension `dim`.
314+
values: af.Array or scalar number
315+
Booleans denoting if all elements are true in `vals` along dimension `dim` according to keys
316+
"""
317+
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)
318+
173319
def any_true(a, dim=None):
174320
"""
175321
Check if any the elements along a specified dimension are true.
@@ -192,6 +338,28 @@ def any_true(a, dim=None):
192338
else:
193339
return _reduce_all(a, backend.get().af_any_true_all)
194340

341+
def anyTrueByKey(keys, vals, dim=-1):
342+
"""
343+
Calculate if any elements are true along a specified dimension according to a key.
344+
345+
Parameters
346+
----------
347+
keys : af.Array
348+
One dimensional arrayfire array with reduction keys.
349+
vals : af.Array
350+
Multi dimensional arrayfire array that will be reduced.
351+
dim: optional: int. default: -1
352+
Dimension along which the any true check will occur.
353+
354+
Returns
355+
-------
356+
keys: af.Array or scalar number
357+
The reduced keys of any true check in `vals` along dimension `dim`.
358+
values: af.Array or scalar number
359+
Booleans denoting if any elements are true in `vals` along dimension `dim` according to keys.
360+
"""
361+
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)
362+
195363
def count(a, dim=None):
196364
"""
197365
Count the number of non zero elements in an array along a specified dimension.
@@ -214,6 +382,28 @@ def count(a, dim=None):
214382
else:
215383
return _reduce_all(a, backend.get().af_count_all)
216384

385+
def countByKey(keys, vals, dim=-1):
386+
"""
387+
Counts non-zero elements along a specified dimension according to a key.
388+
389+
Parameters
390+
----------
391+
keys : af.Array
392+
One dimensional arrayfire array with reduction keys.
393+
vals : af.Array
394+
Multi dimensional arrayfire array that will be reduced.
395+
dim: optional: int. default: -1
396+
Dimension along which to count elements.
397+
398+
Returns
399+
-------
400+
keys: af.Array or scalar number
401+
The reduced keys of count in `vals` along dimension `dim`.
402+
values: af.Array or scalar number
403+
Count of non-zero elements in `vals` along dimension `dim` according to keys.
404+
"""
405+
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)
406+
217407
def imin(a, dim=None):
218408
"""
219409
Find the value and location of the minimum value along a specified dimension

arrayfire/arith.py

+20
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,26 @@ def sqrt(a):
958958
"""
959959
return _arith_unary_func(a, backend.get().af_sqrt)
960960

961+
def rsqrt(a):
962+
"""
963+
Reciprocal or inverse square root of each element in the array.
964+
965+
Parameters
966+
----------
967+
a : af.Array
968+
Multi dimensional arrayfire array.
969+
970+
Returns
971+
--------
972+
out : af.Array
973+
array containing the inverse square root of each value from `a`.
974+
975+
Note
976+
-------
977+
`a` must not be complex.
978+
"""
979+
return _arith_unary_func(a, backend.get().af_rsqrt)
980+
961981
def cbrt(a):
962982
"""
963983
Cube root of each element in the array.

arrayfire/array.py

+8
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,14 @@ def is_single(self):
783783
safe_call(backend.get().af_is_single(c_pointer(res), self.arr))
784784
return res.value
785785

786+
def is_half(self):
787+
"""
788+
Check if the array is of half floating point type (fp16).
789+
"""
790+
res = c_bool_t(False)
791+
safe_call(backend.get().af_is_half(c_pointer(res), self.arr))
792+
return res.value
793+
786794
def is_real_floating(self):
787795
"""
788796
Check if the array is real and of floating point type.

arrayfire/data.py

+52
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,58 @@ def replace(lhs, cond, rhs):
799799
else:
800800
safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, c_double_t(rhs)))
801801

802+
def pad(a, beginPadding, endPadding, padFillType = PAD.ZERO):
803+
"""
804+
Pad an array
805+
806+
This function will pad an array with the specified border size.
807+
Newly padded values can be filled in several different ways.
808+
809+
Parameters
810+
----------
811+
812+
a: af.Array
813+
A multi dimensional input arrayfire array.
814+
815+
beginPadding: tuple of ints. default: (0, 0, 0, 0).
816+
817+
endPadding: tuple of ints. default: (0, 0, 0, 0).
818+
819+
padFillType: optional af.PAD default: af.PAD.ZERO
820+
specifies type of values to fill padded border with
821+
822+
Returns
823+
-------
824+
output: af.Array
825+
A padded array
826+
827+
Examples
828+
---------
829+
>>> import arrayfire as af
830+
>>> a = af.randu(3,3)
831+
>>> af.display(a)
832+
[3 3 1 1]
833+
0.4107 0.1794 0.3775
834+
0.8224 0.4198 0.3027
835+
0.9518 0.0081 0.6456
836+
837+
>>> padded = af.pad(a, (1, 1), (1, 1), af.ZERO)
838+
>>> af.display(padded)
839+
[5 5 1 1]
840+
0.0000 0.0000 0.0000 0.0000 0.0000
841+
0.0000 0.4107 0.1794 0.3775 0.0000
842+
0.0000 0.8224 0.4198 0.3027 0.0000
843+
0.0000 0.9518 0.0081 0.6456 0.0000
844+
0.0000 0.0000 0.0000 0.0000 0.0000
845+
"""
846+
out = Array()
847+
begin_dims = dim4(beginPadding[0], beginPadding[1], beginPadding[2], beginPadding[3])
848+
end_dims = dim4(endPadding[0], endPadding[1], endPadding[2], endPadding[3])
849+
850+
safe_call(backend.get().af_pad(c_pointer(out.arr), a.arr, 4, c_pointer(begin_dims), 4, c_pointer(end_dims), padFillType.value))
851+
return out
852+
853+
802854
def lookup(a, idx, dim=0):
803855
"""
804856
Lookup the values of input array based on index.

arrayfire/device.py

+19
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,25 @@ def is_dbl_supported(device=None):
150150
safe_call(backend.get().af_get_dbl_support(c_pointer(res), dev))
151151
return res.value
152152

153+
def is_half_supported(device=None):
154+
"""
155+
Check if half precision is supported on specified device.
156+
157+
Parameters
158+
-----------
159+
device: optional: int. default: None.
160+
id of the desired device.
161+
162+
Returns
163+
--------
164+
- True if half precision supported.
165+
- False if half precision not supported.
166+
"""
167+
dev = device if device is not None else get_device()
168+
res = c_bool_t(False)
169+
safe_call(backend.get().af_get_half_support(c_pointer(res), dev))
170+
return res.value
171+
153172
def sync(device=None):
154173
"""
155174
Block until all the functions on the device have completed execution.

0 commit comments

Comments
 (0)