23
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
24
// *****************************************************************************
25
25
26
+ #pragma once
27
+
26
28
#include < pybind11/pybind11.h>
27
29
#include < pybind11/stl.h>
28
30
#include < sycl/sycl.hpp>
29
31
30
32
#include " dpctl4pybind11.hpp"
31
- #include " hamming_kernel.hpp"
32
33
#include " utils/output_validation.hpp"
33
34
#include " utils/type_dispatch.hpp"
35
+ #include " utils/type_utils.hpp"
34
36
35
37
namespace dpnp ::extensions::window
36
38
{
37
39
38
40
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
39
41
40
- static kernels::hamming_fn_ptr_t hamming_dispatch_table[dpctl_td_ns::num_types];
41
-
42
42
namespace py = pybind11;
43
43
44
+ typedef sycl::event (*window_fn_ptr_t )(sycl::queue &,
45
+ char *,
46
+ const std::size_t ,
47
+ const std::vector<sycl::event> &);
48
+
49
+ template <typename T, template <typename > class Functor >
50
+ sycl::event window_impl (sycl::queue &q,
51
+ char *result,
52
+ const std::size_t nelems,
53
+ const std::vector<sycl::event> &depends)
54
+ {
55
+ dpctl::tensor::type_utils::validate_type_for_device<T>(q);
56
+
57
+ T *res = reinterpret_cast <T *>(result);
58
+
59
+ sycl::event window_ev = q.submit ([&](sycl::handler &cgh) {
60
+ cgh.depends_on (depends);
61
+
62
+ using WindowKernel = Functor<T>;
63
+ cgh.parallel_for <WindowKernel>(sycl::range<1 >(nelems),
64
+ WindowKernel (res, nelems));
65
+ });
66
+
67
+ return window_ev;
68
+ }
69
+
70
+ template <typename dispatchT>
44
71
std::pair<sycl::event, sycl::event>
45
- py_hamming (sycl::queue &exec_q,
46
- const dpctl::tensor::usm_ndarray &result,
47
- const std::vector<sycl::event> &depends)
72
+ py_window (sycl::queue &exec_q,
73
+ const dpctl::tensor::usm_ndarray &result,
74
+ const std::vector<sycl::event> &depends,
75
+ const dispatchT &window_dispatch_vector)
48
76
{
49
77
dpctl::tensor::validation::CheckWritable::throw_if_not_writable (result);
50
78
@@ -71,52 +99,27 @@ std::pair<sycl::event, sycl::event>
71
99
int result_typenum = result.get_typenum ();
72
100
auto array_types = dpctl_td_ns::usm_ndarray_types ();
73
101
int result_type_id = array_types.typenum_to_lookup_id (result_typenum);
74
- auto fn = hamming_dispatch_table [result_type_id];
102
+ auto fn = window_dispatch_vector [result_type_id];
75
103
76
104
if (fn == nullptr ) {
77
105
throw std::runtime_error (" Type of given array is not supported" );
78
106
}
79
107
80
108
char *result_typeless_ptr = result.get_data ();
81
- sycl::event hamming_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
109
+ sycl::event window_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
82
110
sycl::event args_ev =
83
- dpctl::utils::keep_args_alive (exec_q, {result}, {hamming_ev });
111
+ dpctl::utils::keep_args_alive (exec_q, {result}, {window_ev });
84
112
85
- return std::make_pair (args_ev, hamming_ev );
113
+ return std::make_pair (args_ev, window_ev );
86
114
}
87
115
88
- template <typename fnT, typename T>
89
- struct HammingFactory
116
+ template <template < typename fnT, typename T> typename factoryT >
117
+ void init_window_dispatch_vectors ( window_fn_ptr_t window_dispatch_vector[])
90
118
{
91
- fnT get ()
92
- {
93
- if constexpr (std::is_floating_point_v<T>) {
94
- return kernels::hamming_impl<T>;
95
- }
96
- else {
97
- return nullptr ;
98
- }
99
- }
100
- };
101
-
102
- void init_hamming_dispatch_tables (void )
103
- {
104
- using kernels::hamming_fn_ptr_t ;
105
-
106
- dpctl_td_ns::DispatchVectorBuilder<hamming_fn_ptr_t , HammingFactory,
119
+ dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t , factoryT,
107
120
dpctl_td_ns::num_types>
108
121
contig;
109
- contig.populate_dispatch_vector (hamming_dispatch_table);
110
-
111
- return ;
112
- }
113
-
114
- void init_hamming (py::module_ m)
115
- {
116
- dpnp::extensions::window::init_hamming_dispatch_tables ();
117
-
118
- m.def (" _hamming" , &py_hamming, " Call hamming kernel" , py::arg (" sycl_queue" ),
119
- py::arg (" result" ), py::arg (" depends" ) = py::list ());
122
+ contig.populate_dispatch_vector (window_dispatch_vector);
120
123
121
124
return ;
122
125
}
0 commit comments