Skip to content

Commit e46cd64

Browse files
committed
ENH: Implement broadcast_to function
1 parent fbdd5c7 commit e46cd64

File tree

4 files changed

+117
-11
lines changed

4 files changed

+117
-11
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
)
1616
from ._ops import (
1717
add,
18+
broadcast_to,
1819
reshape,
1920
)
2021

2122
__all__ = [
2223
"add",
24+
"broadcast_to",
2325
"asarray",
2426
"asdtype",
2527
"reshape",

sparse/mlir_backend/_constructors.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,20 +310,20 @@ def __init__(
310310

311311
if obj.format in ("csr", "csc"):
312312
order = "r" if obj.format == "csr" else "c"
313-
index_dtype = asdtype(obj.indptr.dtype)
314-
self._format_class = get_csx_class(self._values_dtype, index_dtype, order)
313+
self._index_dtype = asdtype(obj.indptr.dtype)
314+
self._format_class = get_csx_class(self._values_dtype, self._index_dtype, order)
315315
self._obj = self._format_class.from_sps(obj)
316316
elif obj.format == "coo":
317-
index_dtype = asdtype(obj.coords[0].dtype)
318-
self._format_class = get_coo_class(self._values_dtype, index_dtype)
317+
self._index_dtype = asdtype(obj.coords[0].dtype)
318+
self._format_class = get_coo_class(self._values_dtype, self._index_dtype)
319319
self._obj = self._format_class.from_sps(obj)
320320
else:
321321
raise Exception(f"{obj.format} SciPy format not supported.")
322322

323323
elif _is_numpy_obj(obj):
324324
self._owns_memory = False
325-
index_dtype = asdtype(np.intp)
326-
self._format_class = get_dense_class(self._values_dtype, index_dtype)
325+
self._index_dtype = asdtype(np.intp)
326+
self._format_class = get_dense_class(self._values_dtype, self._index_dtype)
327327
self._obj = self._format_class.from_sps(obj)
328328

329329
elif _is_mlir_obj(obj):
@@ -334,8 +334,8 @@ def __init__(
334334
elif format is not None:
335335
if format == "csf":
336336
self._owns_memory = False
337-
index_dtype = asdtype(np.intp)
338-
self._format_class = get_csf_class(self._values_dtype, index_dtype)
337+
self._index_dtype = asdtype(np.intp)
338+
self._format_class = get_csf_class(self._values_dtype, self._index_dtype)
339339
self._obj = self._format_class.from_sps(obj)
340340
else:
341341
raise Exception(f"Format {format} not supported.")

sparse/mlir_backend/_ops.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,32 @@ def reshape(a, shape):
9595
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
9696

9797

@fn_cache
def get_broadcast_to_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    dimensions: tuple[int, ...],
) -> ir.Module:
    """Build and JIT-compile an MLIR module exposing a ``broadcast_to`` kernel.

    The emitted function takes a single tensor of ``in_tensor_type`` and
    broadcasts it into a freshly allocated tensor of ``out_tensor_type`` via
    ``linalg.broadcast``, with ``dimensions`` naming the added output axes.

    Cached through ``fn_cache`` — which is why ``dimensions`` must arrive as a
    hashable tuple rather than a list.

    NOTE(review): the annotation says ``ir.Module`` but the function actually
    returns an ``mlir.execution_engine.ExecutionEngine`` wrapping the lowered
    module; the sibling ``get_*_module`` helpers presumably share the same
    mismatch — worth fixing across all of them at once.
    """
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def broadcast_to(in_tensor):
                # Destination-passing style: allocate the output tensor and
                # broadcast the input into it.
                out = tensor.empty(out_tensor_type, [])
                return linalg.broadcast(in_tensor, outs=[out], dimensions=dimensions)

        # The C interface is required so the ctypes-based caller can invoke
        # the compiled symbol.
        broadcast_to.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            (CWD / "broadcast_to_module.mlir").write_text(str(module))
        pm.run(module.operation)  # lower through the shared compiler pipeline
        if DEBUG:
            (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
98124
def add(x1: Tensor, x2: Tensor) -> Tensor:
99125
ret_obj = x1._format_class()
100126
out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
@@ -135,3 +161,32 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
135161
)
136162

137163
return Tensor(ret_obj, shape=out_tensor_type.shape)
164+
165+
166+
def _infer_format_class(rank: int, values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
167+
from ._constructors import get_csf_class, get_csx_class, get_dense_class
168+
169+
if rank == 1:
170+
return get_dense_class(values_dtype, index_dtype)
171+
if rank == 2:
172+
return get_csx_class(values_dtype, index_dtype, order="r")
173+
if rank == 3:
174+
return get_csf_class(values_dtype, index_dtype)
175+
raise Exception(f"Rank not supported to infer format: {rank}")
176+
177+
178+
def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) -> Tensor:
179+
x_tensor_type = x._obj.get_tensor_definition(x.shape)
180+
format_class = _infer_format_class(len(shape), x._values_dtype, x._index_dtype)
181+
out_tensor_type = format_class.get_tensor_definition(shape)
182+
ret_obj = format_class()
183+
184+
broadcast_to_module = get_broadcast_to_module(x_tensor_type, out_tensor_type, tuple(dimensions))
185+
186+
broadcast_to_module.invoke(
187+
"broadcast_to",
188+
ctypes.pointer(ctypes.pointer(ret_obj)),
189+
*x._obj.to_module_arg(),
190+
)
191+
192+
return Tensor(ret_obj, shape=shape)

sparse/mlir_backend/tests/test_simple.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ def test_reshape(rng, dtype):
217217
arr = sps.random_array(
218218
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
219219
)
220-
if format == "coo":
221-
arr.sum_duplicates()
220+
arr.sum_duplicates()
222221

223222
tensor = sparse.asarray(arr)
224223

@@ -262,5 +261,55 @@ def test_reshape(rng, dtype):
262261
np.testing.assert_array_equal(actual, expected)
263262

264263
# DENSE
265-
# NOTE: dense reshape is probably broken in MLIR
264+
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
266265
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
266+
267+
268+
@parametrize_dtypes
269+
def test_broadcast_to(dtype):
270+
# CSR, CSC, COO
271+
for shape, new_shape, dimensions, input_arr, expected_arrs in [
272+
(
273+
(3, 4),
274+
(2, 3, 4),
275+
[0],
276+
np.array([[0, 1, 0, 3], [0, 0, 4, 5], [6, 7, 0, 0]]),
277+
[
278+
np.array([0, 3, 6]),
279+
np.array([0, 1, 2, 0, 1, 2]),
280+
np.array([0, 2, 4, 6, 8, 10, 12]),
281+
np.array([1, 3, 2, 3, 0, 1, 1, 3, 2, 3, 0, 1]),
282+
np.array([1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0]),
283+
],
284+
),
285+
(
286+
(4, 2),
287+
(4, 2, 2),
288+
[1],
289+
np.array([[0, 1], [0, 0], [2, 3], [4, 0]]),
290+
[
291+
np.array([0, 2, 2, 4, 6]),
292+
np.array([0, 1, 0, 1, 0, 1]),
293+
np.array([0, 1, 2, 4, 6, 7, 8]),
294+
np.array([1, 1, 0, 1, 0, 1, 0, 0]),
295+
np.array([1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 4.0]),
296+
],
297+
),
298+
]:
299+
for fn_format in [sps.csr_array, sps.csc_array, sps.coo_array]:
300+
arr = fn_format(input_arr, shape=shape, dtype=dtype)
301+
arr.sum_duplicates()
302+
tensor = sparse.asarray(arr)
303+
result = sparse.broadcast_to(tensor, new_shape, dimensions=dimensions).to_scipy_sparse()
304+
305+
for actual, expected in zip(result, expected_arrs, strict=False):
306+
np.testing.assert_allclose(actual, expected)
307+
308+
# DENSE
309+
np_arr = np.array([0, 0, 2, 3, 0, 1])
310+
arr = np.asarray(np_arr, dtype=dtype)
311+
tensor = sparse.asarray(arr)
312+
result = sparse.broadcast_to(tensor, (3, 6), dimensions=[0]).to_scipy_sparse()
313+
314+
assert result.format == "csr"
315+
np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0))

0 commit comments

Comments
 (0)