192 changes: 177 additions & 15 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -1676,6 +1676,14 @@ def view(context, node):
x = inputs[0]
shape = inputs[1]

if np.prod(shape.shape) == 0:
# Reshaping to an empty shape is a no-op (valid only for scalars and single-element tensors)
assert (
np.prod(x.shape) <= 1
), "Reshape to an empty shape works only for scalars and single-element tensors"
context.add(mb.identity(x=x, name=node.name))
return
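
# Illustrative sketch, not part of the original diff: the branch above covers
# torch code like
#     torch.tensor([5.0]).view([])
# where the target shape is the empty shape (), so the op reduces to identity.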

if isinstance(shape, ListVar):
length = mb.list_length(ls=shape)
indices = mb.range_1d(start=0, end=length, step=1)
@@ -3759,39 +3767,180 @@ def _internal_op_tensor_inplace_fill(context, node):


@register_torch_op
def select_scatter(context, node):
inputs = _get_inputs(context, node, expected=4)
x = inputs[0]
updates = inputs[1]
dim = inputs[2].val
if dim is None:
raise ValueError("Only a compile-time-known dim is supported yet")
index = inputs[3]

# mb.torch_tensor_assign handles multi-dim slicing,
# so we need to create full-range slice specifications for all the other dimensions
begin = [0] * x.rank
begin[dim] = index
begin = mb.concat(values=begin, axis=0)
end = x.shape
# and squeeze `dim` so it is treated as pure indexing rather than slicing
squeeze_mask = [False] * x.rank
squeeze_mask[dim] = True

updated_x = _translate_torch_tensor_assign(
x=x,
updates=updates,
begin=begin,
end=end,
stride=None,
begin_mask=None,
end_mask=None,
squeeze_mask=squeeze_mask,
name=node.name,
)
context.add(updated_x)
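
# Illustrative sketch, not part of the original diff: for torch code like
#     x = torch.zeros(2, 3); src = torch.ones(3)
#     torch.select_scatter(x, src, dim=0, index=1)
# the translation above amounts to begin=[1, 0], end=[2, 3],
# squeeze_mask=[True, False], i.e. `src` is written into row 1 of `x`.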


@register_torch_op
def slice_scatter(context, node):
inputs = _get_inputs(context, node, min_expected=2)
x, updates = promote_input_dtypes(inputs[0:2])
dim = 0 if len(inputs) <= 2 else inputs[2].val
if dim is None:
raise ValueError("Only a compile-time-known dim is supported yet")
start = 0 if len(inputs) <= 3 else inputs[3]
end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim])
step = 1 if len(inputs) <= 5 else inputs[5]

assert 0 <= dim < x.rank

# mb.torch_tensor_assign handles multi-dim slicing,
# so we need to expand the scalar start, end, step to per-dimension specifications of length x.rank
starts = [0] * x.rank
starts[dim] = start
starts = mb.concat(values=starts, axis=0)
ends = list(x.shape)
ends[dim] = end
ends = mb.concat(values=ends, axis=0)
steps = [1] * x.rank
steps[dim] = step
steps = mb.concat(values=steps, axis=0)

updated_x = _translate_torch_tensor_assign(
x=x,
updates=updates,
begin=starts,
end=ends,
stride=steps,
begin_mask=None,
end_mask=None,
squeeze_mask=None,
name=node.name,
)
context.add(updated_x)
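
# Illustrative sketch, not part of the original diff: for torch code like
#     x = torch.zeros(2, 4); src = torch.ones(2, 2)
#     torch.slice_scatter(x, src, dim=1, start=0, end=4, step=2)
# the translation above produces starts=[0, 0], ends=[2, 4], steps=[1, 2],
# i.e. `src` is written into columns 0 and 2 of `x`.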


@register_torch_op
def index_put(context, node):
inputs = _get_inputs(context, node, min_expected=3)
x = inputs[0]
indices = inputs[1]
values = inputs[2]
accumulate = False if len(inputs) < 4 else inputs[3].val
mode = "add" if accumulate else "update"

assert isinstance(indices, list), "indices must be a list of tensors"
# Usually indices is a list of non-None tensors, so we stack them and feed them to mb.scatter_nd
# However, when a dimension is taken as a whole slice (i.e. :), its index is represented as None
if any(index is None for index in indices):
# We have 2 ways to translate such torch.index_put, both have pros and cons
# 1. mb.scatter_nd
# * pro: can handle accumulate or update
# * con: can only have whole slice at last dimensions
# 2. mb.torch_tensor_assign
# * pro: can have whole slice at arbitrary dimension
# * con: can only handle update
# Here we use mb.torch_tensor_assign
# TODO: explore how can we cover as many torch.index_put cases as possible
if accumulate:
raise NotImplementedError(
"When any index is a whole slice (e.g. : in x[:, 0]), "
"only torch.index_put(..., accumulate=False) is handled yet"
)
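
# Illustrative sketch, not part of the original diff: for torch code like
#     x[:, 0] = v  # indices == [None, tensor(0)], accumulate == False
# the loop below keeps dim 0 as a whole slice (its begin/end masks stay True)
# and turns dim 1 into the length-1 slice [0:1], so `v` is broadcast to
# expected_values_shape == (x.shape[0], 1) before the assignment.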

begin = [0] * x.rank
end = list(x.shape)
stride = [1] * x.rank
begin_mask = [True] * x.rank
end_mask = [True] * x.rank
# note: we translate an indexed dim into a length-1 slice rather than squeezing it, e.g.
# x = torch.zeros((2, 3))
# y = x[:, 1:2]
# we will get y.shape as (2, 1)
is_dim_unity = [False] * x.rank
for dim, index in enumerate(indices):
if index is not None:
if len(index.shape) > 0:
index = mb.squeeze(x=index)
begin[dim] = index
end[dim] = mb.add(x=index, y=1)
begin_mask[dim] = False
end_mask[dim] = False
is_dim_unity[dim] = True
begin = mb.concat(values=begin, axis=0)
end = mb.concat(values=end, axis=0)

expected_values_shape = []
for dim in range(x.rank):
expected_values_shape.append(1 if is_dim_unity[dim] else x.shape[dim])
expected_values_shape = tuple(expected_values_shape)

if values.shape != expected_values_shape:
values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)

updated_x = _translate_torch_tensor_assign(
x=x,
updates=values,
begin=begin,
end=end,
stride=stride,
begin_mask=begin_mask,
end_mask=end_mask,
squeeze_mask=[False] * x.rank,
name=node.name,
)
context.add(updated_x)
return

indices_type = indices[0].sym_type.get_primitive()
if types.is_bool(indices_type):
# indices
assert len(indices) == 1, "Unsupported index_put_ usage."
indices = indices[0]
assert (
indices.shape == x.shape
), "indices shape must equal the input shape for the index_put operation"
indices = mb.cast(x=indices, dtype="int32")
indices = mb.non_zero(x=indices)

elif types.is_int(indices_type):
# indices
if len(indices) > 1:
indices = mb.stack(values=indices, axis=indices[0].rank)
else:
indices = mb.expand_dims(x=indices[0], axes=[-1])

if len(values.shape) == 0:
values = mb.expand_dims(x=values, axes=[0])

if values.rank == 1 and values.shape[0] == 1:
reps = value_at(mb.shape(x=indices), 0)
reps = mb.expand_dims(x=reps, axes=[0])
values = mb.tile(x=values, reps=reps)
# values
expected_values_shape = indices.shape[:-1] + x.shape[indices.shape[-1] :]
if values.shape != expected_values_shape:
values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)
else:
raise ValueError(f"Only bool and int index handled yet, but got {indices_type}")

if is_current_opset_version_compatible_with(target.iOS17):
# IOS17 `scatter_nd` behaviour is undefined for negative indices.
@@ -5598,7 +5747,20 @@ def std(context, node):
@register_torch_op
def copy(context, node):
inputs = _get_inputs(context, node, expected=[2, 3])
assert (
context.frontend != TorchFrontend.TORCHSCRIPT
), (
"In the TorchScript frontend, the graph pass `generate_tensor_assignment_ops` "
"should have replaced `torch.copy_` with `_internal_op_tensor_inplace_copy`"
)
if context.frontend == TorchFrontend.EXIR:
src = inputs[1]
if inputs[0].shape != src.shape:
_, src = _broadcast_tensors(inputs[:2])
result = mb.identity(x=src, name=node.name)
else:
raise ValueError(f"Invalid PyTorch frontend {context.frontend}")
context.add(result)
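
# Illustrative sketch, not part of the original diff: EXIR graphs are
# functionalized, so `aten.copy` there is out-of-place: it takes (self, src)
# and returns src broadcast to self's shape, with no in-place state to
# preserve; hence the translation above returns the (broadcast) src via
# mb.identity.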


@register_torch_op