 from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
 from pytensor.tensor import basic as at
-from pytensor.tensor.basic import get_vector_length, second
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as at_abs
 from pytensor.tensor.math import all as pt_all

@@ -1561,141 +1561,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)


-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-            NpyIter *iter;
-            PyArrayObject *ops[1] = {%(x)s};
-            npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-            npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-            PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(out_dims)s;
-            int* op_axes[1] = {NULL};
-            npy_intp buffersize = 0;
-
-            for(int i = 0; i < %(inp_dims)s; i++)
-            {
-                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-                {
-                    PyErr_Format(PyExc_ValueError,
-                                 "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                                 i,
-                                 (long long int) itershape[i + %(new_dims)s],
-                                 (long long int) PyArray_DIMS(%(x)s)[i]
-                    );
-                    %(fail)s
-                }
-            }
-
-            iter = NpyIter_AdvancedNew(
-                1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-            );
-            %(out)s = NpyIter_GetIterView(iter, 0);
-
-            if(%(out)s == NULL){
-                NpyIter_Deallocate(iter);
-                %(fail)s;
-            }
-
-            if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-                %(fail)s;
-            }
-
-            """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log

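For context (not part of the diff): the removed `BroadcastTo.grad` reduces the output gradient back to the input's shape by summing over the dimensions that broadcasting added or stretched. Below is a minimal NumPy sketch of that reduction rule, using a hypothetical helper name, assuming only the semantics visible in the deleted code above; the `alloc`-based lowering in the next hunk relies on `Alloc` applying the same reduction in its own gradient.

import numpy as np

def broadcast_to_grad(dout: np.ndarray, a_shape: tuple) -> np.ndarray:
    """Reduce `dout` (gradient w.r.t. the broadcasted output) back to `a_shape`."""
    # Sum over the leading dimensions that broadcasting prepended.
    new_dims = tuple(range(dout.ndim - len(a_shape)))
    grad = dout.sum(axis=new_dims) if new_dims else dout
    # Sum (keeping dims) over axes where the input had length 1 but the output did not.
    bcast_axes = tuple(
        i for i, (a_s, o_s) in enumerate(zip(a_shape, grad.shape)) if a_s == 1 and o_s != 1
    )
    return grad.sum(axis=bcast_axes, keepdims=True) if bcast_axes else grad

dout = np.ones((3, 4))
assert broadcast_to_grad(dout, (4,)).shape == (4,)      # summed over the new leading axis
assert broadcast_to_grad(dout, (1, 4)).shape == (1, 4)  # summed with keepdims over the length-1 axis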
@@ -1739,13 +1604,7 @@ def broadcast_to(
     broadcasted array may refer to a single memory location.

     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)


 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
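For illustration (not part of the diff): after this change, `broadcast_to(x, shape)` simply builds `alloc(x, *shape)`, so values still agree with `numpy.broadcast_to`, though the result is no longer guaranteed to be a memory view of the input. A minimal usage sketch, assuming the import path of the file being modified (`pytensor.tensor.extra_ops`):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_to

x = pt.vector("x", dtype="float64")
y = broadcast_to(x, (3, 4))  # with this change, equivalent to pt.alloc(x, 3, 4)

f = pytensor.function([x], y)
data = np.arange(4.0)
assert np.array_equal(f(data), np.broadcast_to(data, (3, 4)))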