8
8
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
9
9
from torch_tensorrt .dynamo .conversion .converter_utils import (
10
10
broadcastable ,
11
+ cast_trt_tensor ,
11
12
get_positive_dim ,
12
13
get_trt_tensor ,
13
14
to_numpy ,
19
20
set_layer_name ,
20
21
)
21
22
from torch_tensorrt .fx .types import Shape , TRTTensor
23
+ from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
22
24
23
25
import tensorrt as trt
24
26
@@ -399,8 +401,8 @@ def scatter_value(
399
401
source_ir : Optional [SourceIR ],
400
402
name : str ,
401
403
input : TRTTensor ,
402
- dim : Shape ,
403
- index : Shape ,
404
+ dim : int ,
405
+ index : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
404
406
value : float ,
405
407
) -> TRTTensor :
406
408
if not isinstance (input , TRTTensor ):
@@ -410,26 +412,34 @@ def scatter_value(
410
412
)
411
413
input_shape = input .shape
412
414
index_shape = index .shape
415
+ index_shape_list = list (index .shape )
416
+ if not (isinstance (index , TRTTensor )):
417
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
413
418
if len (input_shape ) != len (index_shape ):
414
419
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
415
- ranks = len (input_shape )
416
- dim = get_positive_dim (cast (int , dim ), ranks )
420
+ dim = get_positive_dim (dim , len (input_shape ))
417
421
dynamic_shape = has_dynamic_shape (input .shape )
418
422
if dynamic_shape :
419
423
# Check whether slice target dim is dynamic shape dim
420
424
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
421
425
422
- input_dims = len (input . shape )
426
+ input_dims = len (input_shape )
423
427
for i in range (0 , input_dims ):
424
- if index [i ] >= input .shape [i ]:
428
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
425
429
raise RuntimeError (
426
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
430
+ f"cannot have index size greater than the input size along dimension { dim } "
427
431
)
428
- value_tensor = value * torch .ones (index .shape )
432
+
433
+ value_tensor = get_trt_tensor (
434
+ ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
435
+ )
436
+ value_tensor = cast_trt_tensor (
437
+ ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
438
+ )
429
439
scatter_layer = ctx .net .add_scatter (
430
- input , index , value_tensor , trt .tensorrt . ScatterModekELEMENT
440
+ input , index , value_tensor , trt .ScatterMode . ELEMENT
431
441
)
432
- scatter_layer .set_axis ( dim )
442
+ scatter_layer .axis = dim
433
443
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
434
444
out = scatter_layer .get_output (0 )
435
445
return out
@@ -453,6 +463,8 @@ def scatter_src(
453
463
input_shape = input .shape
454
464
index_shape = index .shape
455
465
src_shape = src .shape
466
+ if not (isinstance (index , TRTTensor )):
467
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
456
468
if len (input_shape ) != len (index_shape ):
457
469
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
458
470
if len (index_shape ) != len (src_shape ):
@@ -466,14 +478,23 @@ def scatter_src(
466
478
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
467
479
468
480
for i in range (0 , input_dims ):
469
- if index [i ] >= input .shape [i ]:
481
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
470
482
raise RuntimeError (
471
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
483
+ f"cannot have index size greater than the input size along dimension { dim } "
472
484
)
485
+ input_dtype = input .dtype
486
+ # required for cases where src is a constant
487
+ src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
488
+ if input_dtype != src_dtype :
489
+ raise RuntimeError (f"The type of input and src should be made" )
490
+ src_tensor = src
491
+ if not (isinstance (src , TRTTensor )):
492
+ src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
493
+
473
494
scatter_layer = ctx .net .add_scatter (
474
- input , index , src , trt .tensorrt . ScatterModekELEMENT
495
+ input , index , src_tensor , trt .ScatterMode . ELEMENT
475
496
)
476
- scatter_layer .set_axis ( dim )
497
+ scatter_layer .axis = dim
477
498
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
478
499
out = scatter_layer .get_output (0 )
479
500
return out
0 commit comments