@@ -395,100 +395,38 @@ def index_select(
395
395
return gather_layer .get_output (0 )
396
396
397
397
398
- def scatter_value (
398
+ def scatter (
399
399
ctx : ConversionContext ,
400
400
target : Target ,
401
401
source_ir : Optional [SourceIR ],
402
402
name : str ,
403
403
input : TRTTensor ,
404
404
dim : int ,
405
405
index : Union [TRTTensor , np .ndarray , torch .Tensor ],
406
- value : float ,
406
+ src : Union [ TRTTensor , int , float ] ,
407
407
) -> TRTTensor :
408
- if not isinstance (input , TRTTensor ):
409
- raise RuntimeError (
410
- f"scatter_tensor received input { input } that is not part "
411
- "of the TensorRT region!"
412
- )
413
408
input_shape = input .shape
414
409
index_shape = index .shape
415
410
index_shape_list = list (index .shape )
416
411
if not (isinstance (index , TRTTensor )):
417
412
index = get_trt_tensor (ctx , index , f"_index_tensor" )
418
- if len (input_shape ) != len (index_shape ):
419
- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
420
413
dim = get_positive_dim (dim , len (input_shape ))
421
414
dynamic_shape = has_dynamic_shape (input .shape )
422
415
if dynamic_shape :
423
416
# Check whether slice target dim is dynamic shape dim
424
417
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
425
418
426
- input_dims = len (input_shape )
427
- for i in range (0 , input_dims ):
428
- if i != dim and (index_shape [i ] >= input .shape [i ]):
429
- raise RuntimeError (
430
- f"cannot have index size greater than the input size along dimension { dim } "
431
- )
432
-
433
- value_tensor = get_trt_tensor (
434
- ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
435
- )
436
- value_tensor = cast_trt_tensor (
437
- ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
438
- )
439
- scatter_layer = ctx .net .add_scatter (
440
- input , index , value_tensor , trt .ScatterMode .ELEMENT
441
- )
442
- scatter_layer .axis = dim
443
- set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
444
- out = scatter_layer .get_output (0 )
445
- return out
446
-
447
-
448
- def scatter_src (
449
- ctx : ConversionContext ,
450
- target : Target ,
451
- source_ir : Optional [SourceIR ],
452
- name : str ,
453
- input : TRTTensor ,
454
- dim : Shape ,
455
- index : Shape ,
456
- src : TRTTensor ,
457
- ) -> TRTTensor :
458
- if not isinstance (input , TRTTensor ):
459
- raise RuntimeError (
460
- f"scatter_tensor received input { input } that is not part "
461
- "of the TensorRT region!"
462
- )
463
- input_shape = input .shape
464
- index_shape = index .shape
465
- src_shape = src .shape
466
- if not (isinstance (index , TRTTensor )):
467
- index = get_trt_tensor (ctx , index , f"_index_tensor" )
468
- if len (input_shape ) != len (index_shape ):
469
- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
470
- if len (index_shape ) != len (src_shape ):
471
- raise RuntimeError (f"The no of dimensions of src and index should be equal" )
472
-
473
- input_dims = len (input_shape )
474
- dim = get_positive_dim (cast (int , dim ), input_dims )
475
- dynamic_shape = has_dynamic_shape (input .shape )
476
- if dynamic_shape :
477
- # Check whether slice target dim is dynamic shape dim
478
- assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
479
-
480
- for i in range (0 , input_dims ):
481
- if i != dim and (index_shape [i ] >= input .shape [i ]):
482
- raise RuntimeError (
483
- f"cannot have index size greater than the input size along dimension { dim } "
484
- )
485
- input_dtype = input .dtype
486
- # required for cases where src is a constant
487
- src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
488
- if input_dtype != src_dtype :
489
- raise RuntimeError (f"The type of input and src should be made" )
490
419
src_tensor = src
491
- if not (isinstance (src , TRTTensor )):
420
+ # scatter.value
421
+ if isinstance (src , int ) or isinstance (src , float ):
422
+ src_tensor = get_trt_tensor (
423
+ ctx , src * torch .ones (index_shape_list ), name + "_value_tensor"
424
+ )
425
+ src_tensor = cast_trt_tensor (
426
+ ctx , src_tensor , input .dtype , name + "_cast_value_tensor"
427
+ )
428
+ # scatter.src
429
+ elif not (isinstance (src , TRTTensor )):
492
430
src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
493
431
494
432
scatter_layer = ctx .net .add_scatter (
0 commit comments