@@ -389,25 +389,25 @@ def scatter_value(
389
389
)
390
390
input_shape = input .shape
391
391
index_shape = index .shape
392
- if (len (input_shape ) != len (index_shape )):
393
- raise RuntimeError (
394
- f"The no of dimensions of input and index should be equal"
395
- )
392
+ if len (input_shape ) != len (index_shape ):
393
+ raise RuntimeError (f"The no of dimensions of input and index should be equal" )
396
394
ranks = len (input_shape )
397
395
dim = get_positive_dim (cast (int , dim ), ranks )
398
396
dynamic_shape = has_dynamic_shape (input .shape )
399
397
if dynamic_shape :
400
398
# Check whether slice target dim is dynamic shape dim
401
399
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
402
-
400
+
403
401
input_dims = len (input .shape )
404
402
for i in range (0 , input_dims ):
405
403
if index [i ] >= input .shape [i ]:
406
404
raise RuntimeError (
407
405
f"cannot have index greater than the dimension length! { input .shape [dim ]} "
408
406
)
409
407
value_tensor = value * torch .ones (index .shape )
410
- scatter_layer = ctx .net .add_scatter (input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT )
408
+ scatter_layer = ctx .net .add_scatter (
409
+ input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT
410
+ )
411
411
scatter_layer .set_axis (dim )
412
412
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
413
413
out = scatter_layer .get_output (0 )
@@ -432,29 +432,27 @@ def scatter_src(
432
432
input_shape = input .shape
433
433
index_shape = index .shape
434
434
src_shape = src .shape
435
- if (len (input_shape ) != len (index_shape )):
436
- raise RuntimeError (
437
- f"The no of dimensions of input and index should be equal"
438
- )
439
- if (len (index_shape ) != len (src_shape )):
440
- raise RuntimeError (
441
- f"The no of dimensions of src and index should be equal"
442
- )
443
-
435
+ if len (input_shape ) != len (index_shape ):
436
+ raise RuntimeError (f"The no of dimensions of input and index should be equal" )
437
+ if len (index_shape ) != len (src_shape ):
438
+ raise RuntimeError (f"The no of dimensions of src and index should be equal" )
439
+
444
440
input_dims = len (input_shape )
445
441
dim = get_positive_dim (cast (int , dim ), input_dims )
446
442
dynamic_shape = has_dynamic_shape (input .shape )
447
443
if dynamic_shape :
448
444
# Check whether slice target dim is dynamic shape dim
449
445
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
450
-
446
+
451
447
for i in range (0 , input_dims ):
452
448
if index [i ] >= input .shape [i ]:
453
449
raise RuntimeError (
454
450
f"cannot have index greater than the dimension length! { input .shape [dim ]} "
455
451
)
456
- scatter_layer = ctx .net .add_scatter (input , index , src , trt .tensorrt .ScatterModekELEMENT )
452
+ scatter_layer = ctx .net .add_scatter (
453
+ input , index , src , trt .tensorrt .ScatterModekELEMENT
454
+ )
457
455
scatter_layer .set_axis (dim )
458
456
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
459
457
out = scatter_layer .get_output (0 )
460
- return out
458
+ return out
0 commit comments