Skip to content

Commit 2700d92

Browse files
authored
create a common file to facilitate the implementation of future window functions. (#2357)
In this PR, a common file has been created to facilitate the implementation of future window functions. In addition a few issues in documentation are addressed.
1 parent e6daa2d commit 2700d92

12 files changed

+119
-131
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010

11-
* Added implementation of `dpnp.hamming` [#2341](https://github.com/IntelPython/dpnp/pull/2341)
11+
* Added implementation of `dpnp.hamming` [#2341](https://github.com/IntelPython/dpnp/pull/2341), [#2357](https://github.com/IntelPython/dpnp/pull/2357)
1212

1313
### Changed
1414

doc/known_words.txt

+8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ al
22
ary
33
backend
44
bandlimited
5+
bincount
56
bitwise
7+
Blackman
68
boolean
79
broadcastable
810
broadcasted
@@ -36,6 +38,7 @@ fs
3638
getter
3739
Golub
3840
Hadamard
41+
Hanning
3942
histogrammed
4043
Hypergeometric
4144
kwargs
@@ -49,9 +52,12 @@ Lanczos
4952
Lomax
5053
Mersenne
5154
meshgrid
55+
minlength
5256
Mises
5357
multinomial
5458
multivalued
59+
namespace
60+
namedtuple
5561
NaN
5662
NaT
5763
nd
@@ -69,6 +75,7 @@ Nyquist
6975
oneAPI
7076
ord
7177
orthonormal
78+
radix
7279
Penrose
7380
Polyutils
7481
pre
@@ -79,6 +86,7 @@ representable
7986
resampling
8087
runtimes
8188
scikit
89+
se
8290
signbit
8391
signum
8492
sinc

doc/reference/linalg.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Other matrix operations
9797

9898
dpnp.diagonal
9999
dpnp.linalg.diagonal (Array API compatible)
100-
dpnp.linalg.matrix_tranpose (Array API compatible)
100+
dpnp.linalg.matrix_transpose (Array API compatible)
101101

102102
Exceptions
103103
----------

doc/reference/logic.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Array type testing
4444

4545

4646
Logical operations
47-
----------------
47+
------------------
4848

4949
.. autosummary::
5050
:toctree: generated/

doc/reference/ndarray.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Constructing arrays
1818
-------------------
1919

2020
New arrays can be constructed using the routines detailed in
21-
:ref:`Array Creation Routines <routines.creation>`, and also by using the low-level
21+
:ref:`Array Creation Routines <routines.array-creation>`, and also by using the low-level
2222
:class:`dpnp.ndarray` constructor:
2323

2424
.. autosummary::

dpnp/backend/extensions/blas/blas_py.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525
//
26-
// This file defines functions of dpnp.backend._lapack_impl extensions
26+
// This file defines functions of dpnp.backend._blas_impl extensions
2727
//
2828
//*****************************************************************************
2929

dpnp/backend/extensions/window/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
set(python_module_name _window_impl)
2828
set(_module_src
29-
${CMAKE_CURRENT_SOURCE_DIR}/hamming.cpp
3029
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
3130
)
3231

dpnp/backend/extensions/window/hamming.cpp renamed to dpnp/backend/extensions/window/common.hpp

+42-39
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,56 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26+
#pragma once
27+
2628
#include <pybind11/pybind11.h>
2729
#include <pybind11/stl.h>
2830
#include <sycl/sycl.hpp>
2931

3032
#include "dpctl4pybind11.hpp"
31-
#include "hamming_kernel.hpp"
3233
#include "utils/output_validation.hpp"
3334
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
3436

3537
namespace dpnp::extensions::window
3638
{
3739

3840
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3941

40-
static kernels::hamming_fn_ptr_t hamming_dispatch_table[dpctl_td_ns::num_types];
41-
4242
namespace py = pybind11;
4343

44+
typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
45+
char *,
46+
const std::size_t,
47+
const std::vector<sycl::event> &);
48+
49+
template <typename T, template <typename> class Functor>
50+
sycl::event window_impl(sycl::queue &q,
51+
char *result,
52+
const std::size_t nelems,
53+
const std::vector<sycl::event> &depends)
54+
{
55+
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
56+
57+
T *res = reinterpret_cast<T *>(result);
58+
59+
sycl::event window_ev = q.submit([&](sycl::handler &cgh) {
60+
cgh.depends_on(depends);
61+
62+
using WindowKernel = Functor<T>;
63+
cgh.parallel_for<WindowKernel>(sycl::range<1>(nelems),
64+
WindowKernel(res, nelems));
65+
});
66+
67+
return window_ev;
68+
}
69+
70+
template <typename dispatchT>
4471
std::pair<sycl::event, sycl::event>
45-
py_hamming(sycl::queue &exec_q,
46-
const dpctl::tensor::usm_ndarray &result,
47-
const std::vector<sycl::event> &depends)
72+
py_window(sycl::queue &exec_q,
73+
const dpctl::tensor::usm_ndarray &result,
74+
const std::vector<sycl::event> &depends,
75+
const dispatchT &window_dispatch_vector)
4876
{
4977
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
5078

@@ -71,52 +99,27 @@ std::pair<sycl::event, sycl::event>
7199
int result_typenum = result.get_typenum();
72100
auto array_types = dpctl_td_ns::usm_ndarray_types();
73101
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
74-
auto fn = hamming_dispatch_table[result_type_id];
102+
auto fn = window_dispatch_vector[result_type_id];
75103

76104
if (fn == nullptr) {
77105
throw std::runtime_error("Type of given array is not supported");
78106
}
79107

80108
char *result_typeless_ptr = result.get_data();
81-
sycl::event hamming_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
109+
sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
82110
sycl::event args_ev =
83-
dpctl::utils::keep_args_alive(exec_q, {result}, {hamming_ev});
111+
dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});
84112

85-
return std::make_pair(args_ev, hamming_ev);
113+
return std::make_pair(args_ev, window_ev);
86114
}
87115

88-
template <typename fnT, typename T>
89-
struct HammingFactory
116+
template <template <typename fnT, typename T> typename factoryT>
117+
void init_window_dispatch_vectors(window_fn_ptr_t window_dispatch_vector[])
90118
{
91-
fnT get()
92-
{
93-
if constexpr (std::is_floating_point_v<T>) {
94-
return kernels::hamming_impl<T>;
95-
}
96-
else {
97-
return nullptr;
98-
}
99-
}
100-
};
101-
102-
void init_hamming_dispatch_tables(void)
103-
{
104-
using kernels::hamming_fn_ptr_t;
105-
106-
dpctl_td_ns::DispatchVectorBuilder<hamming_fn_ptr_t, HammingFactory,
119+
dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t, factoryT,
107120
dpctl_td_ns::num_types>
108121
contig;
109-
contig.populate_dispatch_vector(hamming_dispatch_table);
110-
111-
return;
112-
}
113-
114-
void init_hamming(py::module_ m)
115-
{
116-
dpnp::extensions::window::init_hamming_dispatch_tables();
117-
118-
m.def("_hamming", &py_hamming, "Call hamming kernel", py::arg("sycl_queue"),
119-
py::arg("result"), py::arg("depends") = py::list());
122+
contig.populate_dispatch_vector(window_dispatch_vector);
120123

121124
return;
122125
}

dpnp/backend/extensions/window/hamming.hpp

+36-5
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,42 @@
2525

2626
#pragma once
2727

28-
#include <pybind11/pybind11.h>
28+
#include "common.hpp"
29+
#include <sycl/sycl.hpp>
2930

30-
namespace py = pybind11;
31+
namespace dpnp::extensions::window::kernels
32+
{
3133

32-
namespace dpnp::extensions::window
34+
template <typename T>
35+
class HammingFunctor
3336
{
34-
void init_hamming(py::module_ m);
35-
}
37+
private:
38+
T *data = nullptr;
39+
const std::size_t N;
40+
41+
public:
42+
HammingFunctor(T *data, const std::size_t N) : data(data), N(N) {}
43+
44+
void operator()(sycl::id<1> id) const
45+
{
46+
const auto i = id.get(0);
47+
48+
data[i] = T(0.54) - T(0.46) * sycl::cospi(T(2) * i / (N - 1));
49+
}
50+
};
51+
52+
template <typename fnT, typename T>
53+
struct HammingFactory
54+
{
55+
fnT get()
56+
{
57+
if constexpr (std::is_floating_point_v<T>) {
58+
return window_impl<T, HammingFunctor>;
59+
}
60+
else {
61+
return nullptr;
62+
}
63+
}
64+
};
65+
66+
} // namespace dpnp::extensions::window::kernels

dpnp/backend/extensions/window/hamming_kernel.hpp

-79
This file was deleted.

dpnp/backend/extensions/window/window_py.cpp

+27-1
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,36 @@
2828
//*****************************************************************************
2929

3030
#include <pybind11/pybind11.h>
31+
#include <pybind11/stl.h>
3132

33+
#include "common.hpp"
3234
#include "hamming.hpp"
3335

36+
namespace window_ns = dpnp::extensions::window;
37+
namespace py = pybind11;
38+
using window_ns::window_fn_ptr_t;
39+
40+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
41+
42+
static window_fn_ptr_t hamming_dispatch_vector[dpctl_td_ns::num_types];
43+
3444
PYBIND11_MODULE(_window_impl, m)
3545
{
36-
dpnp::extensions::window::init_hamming(m);
46+
using arrayT = dpctl::tensor::usm_ndarray;
47+
using event_vecT = std::vector<sycl::event>;
48+
49+
{
50+
window_ns::init_window_dispatch_vectors<
51+
window_ns::kernels::HammingFactory>(hamming_dispatch_vector);
52+
53+
auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
54+
const event_vecT &depends = {}) {
55+
return window_ns::py_window(exec_q, result, depends,
56+
hamming_dispatch_vector);
57+
};
58+
59+
m.def("_hamming", hamming_pyapi, "Call hamming kernel",
60+
py::arg("sycl_queue"), py::arg("result"),
61+
py::arg("depends") = py::list());
62+
}
3763
}

dpnp/dpnp_iface_nanfunctions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _replace_nan(a, val):
108108
109109
Returns
110110
-------
111-
out : {dpnp.ndarray}
111+
out : dpnp.ndarray
112112
If `a` is of inexact type, return a copy of `a` with the NaNs
113113
replaced by the fill value, otherwise return `a`.
114114
mask: {bool, None}

0 commit comments

Comments
 (0)