diff --git a/backends/transforms/view_copy_to_squeeze_unsqueeze.py b/backends/transforms/view_copy_to_squeeze_unsqueeze.py index 094ec6a3340..f4a0670072c 100644 --- a/backends/transforms/view_copy_to_squeeze_unsqueeze.py +++ b/backends/transforms/view_copy_to_squeeze_unsqueeze.py @@ -46,16 +46,19 @@ def find_squeeze_dims( i = 0 j = 0 idx = [] - while i < len(input_shape): - if input_shape[i] != view_shape[j]: - if input_shape[i] == 1: - idx.append(i) - j -= 1 - # continue to check remaining dims are equal - else: - return None - i += 1 - j += 1 + while i < len(input_shape) and j < len(view_shape): + if input_shape[i] == view_shape[j]: + i += 1 + j += 1 + elif input_shape[i] == 1: + # squeeze axis on i and check next dim + idx.append(i) + i += 1 + else: + return None + # If there are remaining dimensions, shapes do not match + if i < len(input_shape) or j < len(view_shape): + return None return idx def find_unsqueeze_dim(