From 8c0db3fb2eff75a3a7f4e348e9d2f1f7991b4655 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Fri, 5 Apr 2024 11:15:09 -0700 Subject: [PATCH 1/6] polish mb.scatter_nd doc: fix some errors, de-duplicate symbols, add an example to make it clearer --- .../mil/mil/ops/defs/iOS15/scatter_gather.py | 23 +++++++++++++------ .../ops/tests/iOS14/test_scatter_gather.py | 3 +++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py b/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py index 216bb3e19..671685fbc 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py @@ -486,26 +486,35 @@ class scatter_nd(Operation): Scatter ``updates`` to ``data`` at locations ``indices``. The ``indices`` is a K-dim tensor, where ``indices[i_0,...,i_{K-2}]`` defines a - slice of ``data``, ``K = rank(indices)``, and ``data[indices[i_0, ..., i_{K-2}]]`` - has rank ``rank(data) - indices.shape[-1]``. + slice of ``data``, ``K = rank(indices)``, and ``data[i_0, ..., i_{K-2}]`` + has rank ``rank(data) - indices.shape[-1]``. Concretely, this means the index is + stored in the last dim of ``indices``, e.g. take a ``K == 2`` example + + .. math:: + indices = [[0, 1], + [0, 2]] + + where ``indices[0]`` / ``[0, 1]`` and ``indices[1]`` / ``[0, 2]`` are + two indices that get applied to ``data`` * Example: ``mode == update``: The ``output`` is set to ``data`` initially, and the op updates ``output`` as follows: .. math:: - output[indices[i_0, ..., i_{K-2}]]= updates[indices[i_0, ..., i_{K-2}]] + output[indices[i_0, ..., i_{K-2}]]= updates[i_0, ..., i_{K-2}] * Example: ``mode == add``. The update rule is: .. math:: - output[indices[i_0, ..., i_{K-2}]] += updates[indices[i_0, ..., i_{K-2}]] + output[indices[i_0, ..., i_{K-2}]] += updates[i_0, ..., i_{K-2}] Parameters ---------- data: tensor<\*D, T> (Required) - indices: tensor<\*K, i32> (Required) - updates: tensor<\*K, T> (Required) - * Must be the shape as ``K[:-1]+data.shape[K[-1]:]``. + indices: tensor<\*E, i32> (Required) + * E[-1] <= len(D) + updates: tensor<\*E, T> (Required) + * Must be the shape as ``E[:-1] + data.shape[E[-1]:]``. mode: const string (Optional) * Default to ``add``. * Can be the following modes: ``update``, ``add``, ``sub``, ``mul``, diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_scatter_gather.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_scatter_gather.py index 7c3dd2913..c826653e3 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_scatter_gather.py @@ -122,6 +122,7 @@ def build(data, indices, updates): backend=backend, ) + class TestScatterAlongAxis: @pytest.mark.parametrize( "compute_unit, backend", @@ -346,6 +347,7 @@ def build(data, indices, updates): backend=backend, ) + class TestGather: @pytest.mark.parametrize( "compute_unit, backend", @@ -514,6 +516,7 @@ def prog(x): return x + class TestGatherAlongAxis: @pytest.mark.parametrize( "compute_unit, backend, x_dtype, indices_dtype", From d60574c0056622167814109bfc2a556b1b89a752 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Fri, 5 Apr 2024 12:15:54 -0700 Subject: [PATCH 2/6] support several indexing ops: add select_scatter, add slice_scatter, improve index_put; polish copy and reshape ops along the way --- .../converters/mil/frontend/torch/ops.py | 196 +++++++- .../mil/frontend/torch/test/test_torch_ops.py | 460 ++++++++++++++---- 2 files changed, 556 insertions(+), 100 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 7d54dd9a9..6b0f84260 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -1676,6 +1676,12 @@ def view(context, node): x = inputs[0] shape = inputs[1] + if np.prod(shape.shape) == 0: + # Reshape to empty shape (works only for scalar) is a no op + assert np.prod(x.shape) <= 1, "Reshape to empty shape works only for scalar" + context.add(mb.identity(x=x, name=node.name)) + return + if isinstance(shape, ListVar): length = mb.list_length(ls=shape) indices = mb.range_1d(start=0, end=length, step=1) @@ -3759,18 +3765,164 @@ def _internal_op_tensor_inplace_fill(context, node): @register_torch_op -def index_put(context, node): +def select_scatter(context, node): inputs = _get_inputs(context, node, expected=4) x = inputs[0] + updates = inputs[1] + dim = inputs[2].val + index = inputs[3] + + # mb.torch_tensor_assign handles multi-dim slicing + # so we need to create slice specifications for all other dimensions + begin = [0] * x.rank + begin[dim] = index + begin = mb.concat(values=begin, axis=0) + end = x.shape + stride = [1] * x.rank + + begin_mask = [True] * x.rank + if index.val not in (0, -x.rank): + begin_mask[dim] = False + end_mask = [True] * x.rank + squeeze_mask = [False] * x.rank + squeeze_mask[dim] = True + + updated_x = _translate_torch_tensor_assign( + x=x, + updates=updates, + begin=begin, + end=end, + stride=stride, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + name=node.name, + ) + context.add(updated_x) + + +@register_torch_op +def slice_scatter(context, node): + inputs = _get_inputs(context, node, min_expected=2) + x = inputs[0] + updates = inputs[1] + dim = 0 if len(inputs) <= 2 else inputs[2].val + start = 0 if len(inputs) <= 3 else inputs[3] + end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim]) + step = 1 if len(inputs) <= 5 else inputs[5] + + assert dim is not None, "slice dim must be known at compile time" + assert 0 <= dim and dim < x.rank + + # mb.torch_tensor_assign handles multi-dim slicing + # so we need to pad start, end, step from scalar to x.rank + starts = [0] * x.rank + starts[dim] = start + starts = mb.concat(values=starts, axis=0) + ends = list(x.shape) + ends[dim] = end + ends = mb.concat(values=ends, axis=0) + steps = [1] * x.rank + steps[dim] = step + steps = mb.concat(values=steps, axis=0) + + # mb.torch_tensor_assign also have masks + begin_mask = [True] * x.rank + if start.val not in (0, -x.rank): + begin_mask[dim] = False + end_mask = [True] * x.rank + if end.val is None or end.val < x.shape[dim]: + end_mask[dim] = False + squeeze_mask = [False] * x.rank + + updated_x = _translate_torch_tensor_assign( + x=x, + updates=updates, + begin=starts, + end=ends, + stride=steps, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + name=node.name, + ) + context.add(updated_x) + + +@register_torch_op +def index_put(context, node): + inputs = _get_inputs(context, node, min_expected=3) + x = inputs[0] indices = inputs[1] values = inputs[2] - accumulate = inputs[3].val - rank = x.rank + accumulate = False if len(inputs) < 4 else inputs[3].val mode = "add" if accumulate else "update" - indices_type = indices[0].sym_type.get_primitive() + assert isinstance(indices, list), "indices must be a list of tensors" + # Usually indices is a list of non-None tensors, so we stack them and feed to mb.scatter_nd + # However, when there exists a whole slice, that index is represented as None + exist_slice = False + for index in indices: + if index is None: + exist_slice = True + break + if exist_slice: + # We have 2 ways to translate such torch.index_put, both have pros and cons + # 1. mb.scatter_nd + # * pro: can handle accumulate or update + # * con: can only have whole slice at the endmost dimensions + # 2. mb.torch_tensor_assign + # * pro: can have whole slice at arbitrary dimension + # * con: can only handle update + # Here we use mb.torch_tensor_assign + # TODO: explore how can we cover as many torch.index_put cases as possible + if accumulate: + raise NotImplementedError("If there exists whole slices, only update mode handled yet") + + begin = [0] * x.rank + end = list(x.shape) + stride = [1] * x.rank + begin_mask = [True] * x.rank + end_mask = [True] * x.rank + # note: in torch slice, an indexed dim becomes size 1, rather than squeezed + is_dim_size1 = [False] * x.rank + for dim, index in enumerate(indices): + if index is not None: + if len(index.shape) > 0: + index = mb.squeeze(x=index) + begin[dim] = index + end[dim] = mb.add(x=index, y=1) + begin_mask[dim] = False + end_mask[dim] = False + is_dim_size1[dim] = True + begin = mb.concat(values=begin, axis=0) + end = mb.concat(values=end, axis=0) + + expected_values_shape = [] + for dim in range(x.rank): + expected_values_shape.append(1 if is_dim_size1[dim] else x.shape[dim]) + expected_values_shape = tuple(expected_values_shape) + + if values.shape != expected_values_shape: + values = _broadcast(values.name + "_broadcasted", values, expected_values_shape) + + updated_x = _translate_torch_tensor_assign( + x=x, + updates=values, + begin=begin, + end=end, + stride=stride, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=[False] * x.rank, + name=node.name, + ) + context.add(updated_x) + return + indices_type = indices[0].sym_type.get_primitive() if types.is_bool(indices_type): + # indices assert len(indices) == 1, "Unsupported index_put_ usage." indices = indices[0] assert ( @@ -3778,20 +3930,25 @@ def index_put(context, node): ), "indices shape must equal to input shape for index put operation." indices = mb.cast(x=indices, dtype="int32") indices = mb.non_zero(x=indices) - - if types.is_int(indices_type): + # values + if len(values.shape) == 0: + values = mb.expand_dims(x=values, axes=[0]) + if np.prod(values.shape) <= 1: + reps = value_at(mb.shape(x=indices), 0) + reps = mb.expand_dims(x=reps, axes=[0]) + values = mb.tile(x=values, reps=reps) + elif types.is_int(indices_type): + # indices if len(indices) > 1: indices = mb.stack(values=indices, axis=indices[0].rank) else: indices = mb.expand_dims(x=indices[0], axes=[-1]) - - if len(values.shape) == 0: - values = mb.expand_dims(x=values, axes=[0]) - - if values.rank == 1 and values.shape[0] == 1: - reps = value_at(mb.shape(x=indices), 0) - reps = mb.expand_dims(x=reps, axes=[0]) - values = mb.tile(x=values, reps=reps) + # values + expected_values_shape = indices.shape[:-1] + x.shape[indices.shape[-1] :] + if values.shape != expected_values_shape: + values = _broadcast(values.name + "_broadcasted", values, expected_values_shape) + else: + raise ValueError(f"Only bool and int index handled yet, but got {indices_type}") if is_current_opset_version_compatible_with(target.iOS17): # IOS17 `scatter_nd` behaviour is undefined for negative indices. @@ -5598,7 +5755,16 @@ def std(context, node): @register_torch_op def copy(context, node): inputs = _get_inputs(context, node, expected=[2, 3]) - context.add(mb.identity(x=inputs[0], name=node.name)) + if context.frontend == TorchFrontend.TORCHSCRIPT: + result = mb.identity(x=inputs[0], name=node.name) + elif context.frontend == TorchFrontend.EXIR: + src = inputs[1] + if inputs[0].shape != src.shape: + _, src = _broadcast_tensors(inputs[: 2]) + result = mb.identity(x=src, name=node.name) + else: + raise ValueError(f"Invalid PyTorch frontend {context.frontend}") + context.add(result) @register_torch_op diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index be68ca4a9..fb6878218 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -5090,10 +5090,11 @@ def test_cumsum(self, compute_unit, backend, axis): class TestReshape(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, output_shape, minimum_deployment_target", + "compute_unit, backend, frontend, output_shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, [ (3, 2), (2, -1), @@ -5102,12 +5103,35 @@ class TestReshape(TorchBaseTest): [None, ct.target.iOS17], ), ) - def test_reshape(self, compute_unit, backend, output_shape, minimum_deployment_target): + def test_reshape( + self, compute_unit, backend, frontend, output_shape, minimum_deployment_target + ): input_shape = (2, 3) model = ModuleWrapper(function=torch.reshape, kwargs={"shape": output_shape}) self.run_compare_torch( input_shape, model, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, frontend, minimum_deployment_target", + itertools.product( + compute_units, + backends, + frontends, + [None, ct.target.iOS17], + ), + ) + def test_reshape_scalar(self, compute_unit, backend, frontend, minimum_deployment_target): + model = ModuleWrapper(function=torch.reshape, kwargs={"shape": ()}) + self.run_compare_torch( + (1,), + model, + frontend=frontend, backend=backend, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, @@ -6690,14 +6714,15 @@ def forward(self, x): class TestCopy(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, rank", + "compute_unit, backend, frontend, rank", itertools.product( compute_units, backends, + frontends, [1, 3], ), ) - def test_copy_(self, compute_unit, backend, rank): + def test_copy_(self, compute_unit, backend, frontend, rank): input_shape = np.random.randint(low=2, high=6, size=rank) input_shape = tuple(input_shape) @@ -6708,17 +6733,24 @@ def forward(self, x): return y model = CopyModel() - self.run_compare_torch(input_shape, model, backend=backend, compute_unit=compute_unit) + self.run_compare_torch( + input_shape, + model, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) @pytest.mark.parametrize( - "compute_unit, backend, rank", + "compute_unit, backend, frontend, rank", itertools.product( compute_units, backends, + frontends, [1, 3], ), ) - def test_copy__2(self, compute_unit, backend, rank): + def test_copy__2(self, compute_unit, backend, frontend, rank): input_shape = np.random.randint(low=2, high=6, size=rank) input_shape = tuple(input_shape) @@ -6729,7 +6761,13 @@ def forward(self, x): return y + 1 model = CopyModel() - self.run_compare_torch(input_shape, model, backend=backend, compute_unit=compute_unit) + self.run_compare_torch( + input_shape, + model, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) class TestZeros(TorchBaseTest): @@ -7360,10 +7398,11 @@ def forward(self, x): class TestSelect(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, dim_index", + "compute_unit, backend, frontend, dim_index", itertools.product( compute_units, backends, + frontends, [ [0, 0], [1, 1], @@ -7371,7 +7410,7 @@ class TestSelect(TorchBaseTest): ], ), ) - def test_select(self, compute_unit, backend, dim_index): + def test_select(self, compute_unit, backend, frontend, dim_index): dim, index = dim_index class SelectModel(nn.Module): @@ -7379,17 +7418,25 @@ def forward(self, x): return x.select(dim, index) input_shape = (1, 2, 3) - model = SelectModel() self.run_compare_torch( - input_shape, model, backend=backend, compute_unit=compute_unit + input_shape, + SelectModel(), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, ) - @pytest.mark.parametrize( - "compute_unit, backend", - itertools.product(compute_units, backends) + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends) ) - def test_dynamic_index(self, compute_unit, backend): + def test_dynamic_index(self, compute_unit, backend, frontend): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2189: " + "torch.export Cannot Use Dynamic Index to Select" + ) + class M(torch.nn.Module): def forward(self, float_arr, int_arr): dynamic_index = int_arr[1] @@ -7408,10 +7455,31 @@ def forward(self, float_arr, int_arr): M(), input_as_shape=False, converter_input_type=inputs_types, + frontend=frontend, backend=backend, compute_unit=compute_unit ) + @pytest.mark.parametrize( + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends), + ) + def test_dynamic_index_with_explicit_slice_on_all_other_dims(self, compute_unit, backend, frontend): + class SelectModel(torch.nn.Module): + def forward(self, x, position): + y = x[:, :, position] + return y + + self.run_compare_torch( + [(2, 3, 4), (1,)], + SelectModel(), + input_dtype=np.int32, + rand_range=(0, 2), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + class TestNonZero(TorchBaseTest): @pytest.mark.parametrize( @@ -7832,14 +7900,21 @@ def forward(self, x): class TestIndexPut(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, minimum_deployment_target", + "compute_unit, backend, frontend, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, [None, ct.target.iOS17], ), ) - def test_index_put_case_1(self, compute_unit, backend, minimum_deployment_target): + def test_index_put_bool_index_case_1(self, compute_unit, backend, frontend, minimum_deployment_target): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2183: " + "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" + ) + class IndexPutModel(torch.nn.Module): def forward(self, x, y): y = x + 1 @@ -7851,21 +7926,31 @@ def forward(self, x, y): self.run_compare_torch( [shape, shape], IndexPutModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, ) @pytest.mark.parametrize( - "compute_unit, backend, rank, minimum_deployment_target", + "compute_unit, backend, frontend, rank, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, [0, 1], [None, ct.target.iOS17], ), ) - def test_index_put_case_2(self, compute_unit, backend, rank, minimum_deployment_target): + def test_index_put_bool_index_case_2( + self, compute_unit, backend, frontend, rank, minimum_deployment_target + ): + if backend[0] == "neuralnetwork" and frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2185: " + "EXIR IndexPut Fails on NeuralNetwork Backend" + ) + class IndexPutModel(torch.nn.Module): def forward(self, x): mask = torch.tensor([True, False, False, False, True, True]).view(3, 2) @@ -7875,20 +7960,30 @@ def forward(self, x): x[mask] = torch.tensor([1.0]) return x - shape = (3, 2) - model = IndexPutModel() - self.run_compare_torch(shape, model, backend=backend, compute_unit=compute_unit, - minimum_deployment_target=minimum_deployment_target) + self.run_compare_torch( + (3, 2), + IndexPutModel(), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target + ) @pytest.mark.parametrize( - "compute_unit, backend, minimum_deployment_target", + "compute_unit, backend, frontend, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, [None, ct.target.iOS17], ), ) - def test_index_put_case_3(self, compute_unit, backend, minimum_deployment_target): + def test_index_put_dynamic_bool_index(self, compute_unit, backend, frontend, minimum_deployment_target): + if backend[0] == "neuralnetwork" and frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2185: " + "EXIR IndexPut Fails on NeuralNetwork Backend" + ) if _macos_version() < (13, 0): pytest.skip("Issue fixed in iOS16/macOS13") @@ -7902,10 +7997,10 @@ def forward(self, x, y): torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6]), torch.Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), ] - model = IndexPutModel() self.run_compare_torch( inputs, - model, + IndexPutModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, @@ -7913,10 +8008,25 @@ def forward(self, x, y): ) @pytest.mark.parametrize( - "compute_unit, backend, rank, accumulate, minimum_deployment_target", - itertools.product(compute_units, backends, [3], [True, False], [None, ct.target.iOS17]), + "compute_unit, backend, frontend, rank, accumulate, minimum_deployment_target", + itertools.product( + compute_units, + backends, + frontends, + [3], + [True, False], + [None, ct.target.iOS17], + ), ) - def test_index_put_case_4(self, compute_unit, backend, rank, accumulate, minimum_deployment_target): + def test_index_put_int_index_case_1( + self, compute_unit, backend, frontend, rank, accumulate, minimum_deployment_target + ): + if backend[0] == "neuralnetwork" and frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2185: " + "EXIR IndexPut Fails on NeuralNetwork Backend" + ) + class IndexPutModel(torch.nn.Module): def forward(self, x, indices, values): x.index_put_(tuple(indices.t()), values, accumulate=accumulate) @@ -7945,6 +8055,7 @@ def forward(self, x, indices, values): self.run_compare_torch( inputs, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, @@ -7952,12 +8063,117 @@ def forward(self, x, indices, values): ) @pytest.mark.parametrize( - "compute_unit, backend, accumulate, minimum_deployment_target", - itertools.product(compute_units, backends, [True, False], [None, ct.target.iOS17]), + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends), + ) + def test_index_put_int_index_case_2(self, compute_unit, backend, frontend): + class IndexPutModel(torch.nn.Module): + def forward(self, x): + box_corner = x.new(x.shape) + box_corner[:, :, 0] = x[:, :, 0] + box_corner[:, :, 1] = x[:, :, 1] + return box_corner[:, :, :2] + + self.run_compare_torch( + (2, 3, 4), + IndexPutModel(), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends), + ) + def test_index_put_int_index_case_3(self, compute_unit, backend, frontend): + class IndexPutModel(torch.nn.Module): + def forward(self, x): + y = x.clone() + y[:, 0] = 1.0 + return y + + self.run_compare_torch( + (2, 3), + IndexPutModel(), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, frontend, val_shape", + itertools.product(compute_units, backends, frontends, ((2, 1), (1,))), + ) + def test_index_put_dynamic_int_index_case_1(self, compute_unit, backend, frontend, val_shape): + if frontend == TorchFrontend.TORCHSCRIPT: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2188: " + "torch.jit.trace Inplace Index Put Silent Error" + ) + + class IndexPutModel(torch.nn.Module): + def forward(self, x, position, val): + y = x.clone() + y[:, position] = val + return y + + self.run_compare_torch( + [(2, 3), (1,), val_shape], + IndexPutModel(), + input_dtype=np.int32, + rand_range=(0, 2), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends), + ) + def test_index_put_dynamic_int_index_case_2(self, compute_unit, backend, frontend): + if frontend == TorchFrontend.TORCHSCRIPT: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2188: " + "torch.jit.trace Inplace Index Put Silent Error" + ) + + class IndexPutModel(torch.nn.Module): + def forward(self, x, position, val): + y = x.clone() + y[position, 1:4] = val + return y + + self.run_compare_torch( + [(2, 4), (1,), (1,)], + IndexPutModel(), + input_dtype=np.int32, + rand_range=(0, 2), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, frontend, accumulate, minimum_deployment_target", + itertools.product( + compute_units, + backends, + frontends, + [True, False], + [None, ct.target.iOS17], + ), ) def test_index_put_negative_indices_case_1( - self, compute_unit, backend, accumulate, minimum_deployment_target + self, compute_unit, backend, frontend, accumulate, minimum_deployment_target ): + if backend[0] == "neuralnetwork" and frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2185: " + "EXIR IndexPut Fails on NeuralNetwork Backend" + ) + class IndexPutModel(torch.nn.Module): def forward(self, x): x.index_put_( @@ -7970,20 +8186,32 @@ def forward(self, x): self.run_compare_torch( (3, 4), IndexPutModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, ) @pytest.mark.parametrize( - "compute_unit, backend, rank, accumulate, minimum_deployment_target", + "compute_unit, backend, frontend, rank, accumulate, minimum_deployment_target", itertools.product( - compute_units, backends, [1, 2, 3], [True, False], [None, ct.target.iOS17] + compute_units, + backends, + frontends, + [1, 2, 3], + [True, False], + [None, ct.target.iOS17], ), ) def test_index_put_negative_indices_case_2( - self, compute_unit, backend, rank, accumulate, minimum_deployment_target + self, compute_unit, backend, frontend, rank, accumulate, minimum_deployment_target ): + if backend[0] == "neuralnetwork" and frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2185: " + "EXIR IndexPut Fails on NeuralNetwork Backend" + ) + class IndexPutModel(torch.nn.Module): def forward(self, x, indices, values): x.index_put_(tuple(indices.t()), values, accumulate=accumulate) @@ -8012,38 +8240,21 @@ def forward(self, x, indices, values): self.run_compare_torch( inputs, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, minimum_deployment_target=minimum_deployment_target, ) - @pytest.mark.parametrize( - "compute_unit, backend", - itertools.product(compute_units, backends), - ) - def test_index_put_case_5(self, compute_unit, backend): - class IndexPutModel(torch.nn.Module): - def forward(self, x): - box_corner = x.new(x.shape) - box_corner[:, :, 0] = x[:, :, 0] - box_corner[:, :, 1] = x[:, :, 1] - return box_corner[:, :, :2] - - self.run_compare_torch( - (2, 3, 4), - IndexPutModel(), - backend=backend, - compute_unit=compute_unit, - ) - class TestIndex(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (10,), @@ -8052,8 +8263,15 @@ class TestIndex(TorchBaseTest): [None, ct.target.iOS17], ), ) - def test_index_bool_indices(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): - rank = len(shape) + def test_index_bool_indices( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2183: " + "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" + ) + class IndexModel(torch.nn.Module): def __init__(self, axis): super().__init__() @@ -8071,6 +8289,7 @@ def forward(self, x, y): assert self.axis == 3 return x[:, :, :, index] + rank = len(shape) for index_rank in range(1, rank + 1): for axis in range(rank + 1 - index_rank): input_data = generate_input_data(shape, rand_range=(0, 2), dtype=input_dtype) @@ -8084,6 +8303,7 @@ def forward(self, x, y): self.run_compare_torch( [input_data, ref_data], model, + frontend=frontend, backend=backend, compute_unit=compute_unit, input_as_shape=False, @@ -8091,10 +8311,11 @@ def forward(self, x, y): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2), @@ -8103,7 +8324,15 @@ def forward(self, x, y): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_1(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_1( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2184: " + "Cannot Convert Empty EXIR Model" + ) + # all elements are selected class IndexModel(torch.nn.Module): def forward(self, x): @@ -8116,6 +8345,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8124,10 +8354,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2), @@ -8136,7 +8367,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_2(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_2( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Only one axis is sliced.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8151,6 +8384,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8159,10 +8393,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8171,7 +8406,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_3(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_3( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Only two axes are sliced, and connected.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8189,6 +8426,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8197,10 +8435,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8209,7 +8448,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_4(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_4( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Only two axes are sliced, and not connected.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8227,6 +8468,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8235,10 +8477,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8247,7 +8490,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_5(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_5( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """All axes are sliced.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8268,6 +8513,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8276,10 +8522,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2), @@ -8288,7 +8535,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_6(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_6( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Only one axis is sliced + nd mode.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8305,6 +8554,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8313,10 +8563,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8325,7 +8576,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_7(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_7( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Two axes are sliced, and connected + nd mode.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8343,6 +8596,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8351,10 +8605,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8363,7 +8618,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_8(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_8( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Two axes are sliced, and not connected + nd mode.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8381,6 +8638,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8389,10 +8647,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8401,7 +8660,15 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_9(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_9( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2183: " + "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" + ) + """One axis is sliced through bool mask.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8415,6 +8682,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8423,10 +8691,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8435,7 +8704,15 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_10(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_10( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2183: " + "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" + ) + """Multiple axes are sliced through bool masks with possible broadcasting.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8464,6 +8741,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8472,10 +8750,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (3, 4), @@ -8484,7 +8763,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_11(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_11( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Broadcastable indices.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8503,6 +8784,7 @@ def forward(self, x): self.run_compare_torch( shape, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8511,10 +8793,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8523,7 +8806,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_12(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_12( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Another broadcastable indices test case.""" class IndexModel(torch.nn.Module): def forward(self, x): @@ -8538,6 +8823,7 @@ def forward(self, x): self.run_compare_torch( shape, IndexModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), @@ -8546,10 +8832,11 @@ def forward(self, x): ) @pytest.mark.parametrize( - "compute_unit, backend, input_dtype, shape, minimum_deployment_target", + "compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target", itertools.product( compute_units, backends, + frontends, (np.float32, np.int32, np.bool_), [ (1, 2, 3), @@ -8558,7 +8845,9 @@ def forward(self, x): [None, ct.target.iOS17], ), ) - def test_index_int_index_case_13(self, compute_unit, backend, input_dtype, shape, minimum_deployment_target): + def test_index_int_index_case_13( + self, compute_unit, backend, frontend, input_dtype, shape, minimum_deployment_target + ): """Another broadcastable indices (negative) test case.""" class IndexModel(torch.nn.Module): @@ -8570,6 +8859,7 @@ def forward(self, x): self.run_compare_torch( shape, IndexModel(), + frontend=frontend, backend=backend, compute_unit=compute_unit, rand_range=(0, 2), From 8422fcb6d727c7e7d62ae2fb79b3126062b7a06f Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Mon, 8 Apr 2024 17:28:48 -0700 Subject: [PATCH 3/6] address review comment on scatter_nd doc: indices and updates have different shapes; make doc more explicit --- .../converters/mil/mil/ops/defs/iOS15/scatter_gather.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py b/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py index 671685fbc..788a2e418 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py @@ -512,9 +512,9 @@ class scatter_nd(Operation): ---------- data: tensor<\*D, T> (Required) indices: tensor<\*E, i32> (Required) - * E[-1] <= len(D) - updates: tensor<\*E, T> (Required) - * Must be the shape as ``E[:-1] + data.shape[E[-1]:]``. + * indices.shape[-1] <= data.rank + updates: tensor<\*F, T> (Required) + * Must be the shape as ``indices.shape[:-1] + data.shape[indices.shape[-1]:]``. mode: const string (Optional) * Default to ``add``. * Can be the following modes: ``update``, ``add``, ``sub``, ``mul``, From 2849faf6d445304df4bd9eed38758bb08a742c55 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Mon, 8 Apr 2024 17:54:51 -0700 Subject: [PATCH 4/6] address review comments: simplify code with any(map(lambda; clarify error messages and docs, mostly by adding examples --- .../converters/mil/frontend/torch/ops.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 6b0f84260..cf4912fc2 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -1678,7 +1678,9 @@ def view(context, node): if np.prod(shape.shape) == 0: # Reshape to empty shape (works only for scalar) is a no op - assert np.prod(x.shape) <= 1, "Reshape to empty shape works only for scalar" + assert ( + np.prod(x.shape) <= 1 + ), "Reshape to empty shape works only for scalar and single-element tensor" context.add(mb.identity(x=x, name=node.name)) return @@ -3770,6 +3772,8 @@ def select_scatter(context, node): x = inputs[0] updates = inputs[1] dim = inputs[2].val + if dim is None: + raise ValueError("Only compile time known dim supported yet") index = inputs[3] # mb.torch_tensor_assign handles multi-dim slicing @@ -3860,32 +3864,33 @@ def index_put(context, node): assert isinstance(indices, list), "indices must be a list of tensors" # Usually indices is a list of non-None tensors, so we stack them and feed to mb.scatter_nd - # However, when there exists a whole slice, that index is represented as None - exist_slice = False - for index in indices: - if index is None: - exist_slice = True - break - if exist_slice: + # However, when there exists a whole slice (i.e. :), that index is represented as None + if any(map(lambda index: index is None, indices)): # We have 2 ways to translate such torch.index_put, both have pros and cons # 1. mb.scatter_nd # * pro: can handle accumulate or update - # * con: can only have whole slice at the endmost dimensions + # * con: can only have whole slice at last dimensions # 2. mb.torch_tensor_assign # * pro: can have whole slice at arbitrary dimension # * con: can only handle update # Here we use mb.torch_tensor_assign # TODO: explore how can we cover as many torch.index_put cases as possible if accumulate: - raise NotImplementedError("If there exists whole slices, only update mode handled yet") + raise NotImplementedError( + "If there existed any whole slice (e.g. : in x[:, 0]), " + "only torch.index_put(..., accumulate=False) handled yet" + ) begin = [0] * x.rank end = list(x.shape) stride = [1] * x.rank begin_mask = [True] * x.rank end_mask = [True] * x.rank - # note: in torch slice, an indexed dim becomes size 1, rather than squeezed - is_dim_size1 = [False] * x.rank + # note: in torch slice, an indexed dim becomes size 1, rather than squeezed, e.g. + # x = torch.zeros((2, 3)) + # y = x[:, 1] + # we will get y.shape as (2, 1) + is_dim_unity = [False] * x.rank for dim, index in enumerate(indices): if index is not None: if len(index.shape) > 0: @@ -3894,13 +3899,13 @@ def index_put(context, node): end[dim] = mb.add(x=index, y=1) begin_mask[dim] = False end_mask[dim] = False - is_dim_size1[dim] = True + is_dim_unity[dim] = True begin = mb.concat(values=begin, axis=0) end = mb.concat(values=end, axis=0) expected_values_shape = [] for dim in range(x.rank): - expected_values_shape.append(1 if is_dim_size1[dim] else x.shape[dim]) + expected_values_shape.append(1 if is_dim_unity[dim] else x.shape[dim]) expected_values_shape = tuple(expected_values_shape) if values.shape != expected_values_shape: @@ -3931,9 +3936,9 @@ def index_put(context, node): indices = mb.cast(x=indices, dtype="int32") indices = mb.non_zero(x=indices) # values - if len(values.shape) == 0: + if values.shape == (): values = mb.expand_dims(x=values, axes=[0]) - if np.prod(values.shape) <= 1: + if values.rank == 1 and values.shape[0] == 1: reps = value_at(mb.shape(x=indices), 0) reps = mb.expand_dims(x=reps, axes=[0]) values = mb.tile(x=values, reps=reps) @@ -5755,9 +5760,13 @@ def std(context, node): @register_torch_op def copy(context, node): inputs = _get_inputs(context, node, expected=[2, 3]) - if context.frontend == TorchFrontend.TORCHSCRIPT: - result = mb.identity(x=inputs[0], name=node.name) - elif context.frontend == TorchFrontend.EXIR: + assert ( + context.frontend != TorchFrontend.TORCHSCRIPT + ), ( + "In torch script frontend, by graph pass `generate_tensor_assignment_ops`, " + "`torch.copy_` should have been replaced with `_internal_op_tensor_inplace_copy`" + ) + if context.frontend == TorchFrontend.EXIR: src = inputs[1] if inputs[0].shape != src.shape: _, src = _broadcast_tensors(inputs[: 2]) From e0f2377272e6e4c1f9b26900612336ccf388e0ea Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Mon, 8 Apr 2024 20:08:08 -0700 Subject: [PATCH 5/6] fix model pyhpc_turbulent_kinetic_energy: need to promote slice_scatter x and updates --- coremltools/converters/mil/frontend/torch/ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index cf4912fc2..1210a2a81 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -3808,8 +3808,7 @@ def select_scatter(context, node): @register_torch_op def slice_scatter(context, node): inputs = _get_inputs(context, node, min_expected=2) - x = inputs[0] - updates = inputs[1] + x, updates = promote_input_dtypes(inputs[0:2]) dim = 0 if len(inputs) <= 2 else inputs[2].val start = 0 if len(inputs) <= 3 else inputs[3] end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim]) From 25a01733b778c10175b8c4f5a0fbd57895692afd Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Tue, 9 Apr 2024 15:43:42 -0700 Subject: [PATCH 6/6] address review comment: remove unnecessary mask --- .../converters/mil/frontend/torch/ops.py | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 1210a2a81..b39ac47c8 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -3782,12 +3782,7 @@ def select_scatter(context, node): begin[dim] = index begin = mb.concat(values=begin, axis=0) end = x.shape - stride = [1] * x.rank - - begin_mask = [True] * x.rank - if index.val not in (0, -x.rank): - begin_mask[dim] = False - end_mask = [True] * x.rank + # and squeeze dim to do pure indexing on it squeeze_mask = [False] * x.rank squeeze_mask[dim] = True @@ -3796,9 +3791,9 @@ def select_scatter(context, node): updates=updates, begin=begin, end=end, - stride=stride, - begin_mask=begin_mask, - end_mask=end_mask, + stride=None, + begin_mask=None, + end_mask=None, squeeze_mask=squeeze_mask, name=node.name, ) @@ -3810,6 +3805,8 @@ def slice_scatter(context, node): inputs = _get_inputs(context, node, min_expected=2) x, updates = promote_input_dtypes(inputs[0:2]) dim = 0 if len(inputs) <= 2 else inputs[2].val + if dim is None: + raise ValueError("Only compile time known dim supported yet") start = 0 if len(inputs) <= 3 else inputs[3] end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim]) step = 1 if len(inputs) <= 5 else inputs[5] @@ -3829,24 +3826,15 @@ def slice_scatter(context, node): steps[dim] = step steps = mb.concat(values=steps, axis=0) - # mb.torch_tensor_assign also have masks - begin_mask = [True] * x.rank - if start.val not in (0, -x.rank): - begin_mask[dim] = False - end_mask = [True] * x.rank - if end.val is None or end.val < x.shape[dim]: - end_mask[dim] = False - squeeze_mask = [False] * x.rank - updated_x = _translate_torch_tensor_assign( x=x, updates=updates, begin=starts, end=ends, stride=steps, - begin_mask=begin_mask, - end_mask=end_mask, - squeeze_mask=squeeze_mask, + begin_mask=None, + end_mask=None, + squeeze_mask=None, name=node.name, ) context.add(updated_x)