diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index b78ee94b774..2b345cb5118 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -46,24 +46,26 @@ def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]: (1,0,2) """ N = len(permutation_matrix) - assert N == len( - permutation_matrix[0] - ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}" + if N != len(permutation_matrix[0]): + raise ValueError( + f"A permutation matrix must be square, got shape {permutation_matrix.shape}" + ) p = [0] * N for row_index, row in enumerate(permutation_matrix): saw_one = False for col_index, value in enumerate(row): if value == 1: - assert ( - not saw_one - ), f"A permutation matrix can only have one 1 per row, got row {row}." + if saw_one: + raise ValueError( + f"A permutation matrix can only have one 1 per row, got {row=}" + ) p[row_index] = col_index saw_one = True - else: - assert ( - value == 0 - ), f"A permutation matrix only contains 1's and 0's, got value {value}." + elif value != 0: + raise ValueError( + f"A permutation matrix only contains 1's and 0's, got {value=}" + ) return p