Skip to content

Commit a4e506d

Browse files
committed
Linting fix
1 parent 1d8af0c commit a4e506d

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -389,25 +389,25 @@ def scatter_value(
389389
)
390390
input_shape = input.shape
391391
index_shape = index.shape
392-
if (len(input_shape) != len(index_shape)):
393-
raise RuntimeError(
394-
f"The no of dimensions of input and index should be equal"
395-
)
392+
if len(input_shape) != len(index_shape):
393+
raise RuntimeError(f"The no of dimensions of input and index should be equal")
396394
ranks = len(input_shape)
397395
dim = get_positive_dim(cast(int, dim), ranks)
398396
dynamic_shape = has_dynamic_shape(input.shape)
399397
if dynamic_shape:
400398
# Check whether slice target dim is dynamic shape dim
401399
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
402-
400+
403401
input_dims = len(input.shape)
404402
for i in range(0, input_dims):
405403
if index[i] >= input.shape[i]:
406404
raise RuntimeError(
407405
f"cannot have index greater than the dimension length! {input.shape[dim]}"
408406
)
409407
value_tensor = value * torch.ones(index.shape)
410-
scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
408+
scatter_layer = ctx.net.add_scatter(
409+
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
410+
)
411411
scatter_layer.set_axis(dim)
412412
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
413413
out = scatter_layer.get_output(0)
@@ -432,29 +432,27 @@ def scatter_src(
432432
input_shape = input.shape
433433
index_shape = index.shape
434434
src_shape = src.shape
435-
if (len(input_shape) != len(index_shape)):
436-
raise RuntimeError(
437-
f"The no of dimensions of input and index should be equal"
438-
)
439-
if (len(index_shape) != len(src_shape)):
440-
raise RuntimeError(
441-
f"The no of dimensions of src and index should be equal"
442-
)
443-
435+
if len(input_shape) != len(index_shape):
436+
raise RuntimeError(f"The no of dimensions of input and index should be equal")
437+
if len(index_shape) != len(src_shape):
438+
raise RuntimeError(f"The no of dimensions of src and index should be equal")
439+
444440
input_dims = len(input_shape)
445441
dim = get_positive_dim(cast(int, dim), input_dims)
446442
dynamic_shape = has_dynamic_shape(input.shape)
447443
if dynamic_shape:
448444
# Check whether slice target dim is dynamic shape dim
449445
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
450-
446+
451447
for i in range(0, input_dims):
452448
if index[i] >= input.shape[i]:
453449
raise RuntimeError(
454450
f"cannot have index greater than the dimension length! {input.shape[dim]}"
455451
)
456-
scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT)
452+
scatter_layer = ctx.net.add_scatter(
453+
input, index, src, trt.tensorrt.ScatterModekELEMENT
454+
)
457455
scatter_layer.set_axis(dim)
458456
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
459457
out = scatter_layer.get_output(0)
460-
return out
458+
return out

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def __init__(self):
2323
def forward(self, input, src):
2424
return torch.ops.aten.scatter.value(input, dim, index, value)
2525

26-
input = [torch.zeros(3, 5, dtype = torch.int32)]
26+
input = [torch.zeros(3, 5, dtype=torch.int32)]
2727
self.run_test(
2828
TestModule(),
2929
input,
@@ -46,14 +46,11 @@ def __init__(self):
4646

4747
def forward(self, input, src):
4848
return torch.ops.aten.scatter.src(input, dim, index, src)
49-
50-
src = [torch.arange(1, 11).reshape((2,5))]
51-
input = torch.zeros(3, 5, dtype = src.dtype)
49+
50+
src = [torch.arange(1, 11).reshape((2, 5))]
51+
input = torch.zeros(3, 5, dtype=src.dtype)
5252
inputs = [input, src]
5353
self.run_test(
5454
TestModule(),
5555
inputs,
5656
)
57-
58-
59-

0 commit comments

Comments (0)