9
9
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
10
10
from torch_tensorrt .dynamo .conversion .converter_utils import (
11
11
broadcastable ,
12
+ cast_trt_tensor ,
12
13
get_positive_dim ,
13
14
get_trt_tensor ,
14
15
to_numpy ,
20
21
set_layer_name ,
21
22
)
22
23
from torch_tensorrt .fx .types import Shape , TRTTensor
24
+ from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
23
25
24
26
_LOGGER : logging .Logger = logging .getLogger (__name__ )
25
27
@@ -378,8 +380,8 @@ def scatter_value(
378
380
source_ir : Optional [SourceIR ],
379
381
name : str ,
380
382
input : TRTTensor ,
381
- dim : Shape ,
382
- index : Shape ,
383
+ dim : int ,
384
+ index : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
383
385
value : float ,
384
386
) -> TRTTensor :
385
387
if not isinstance (input , TRTTensor ):
@@ -389,26 +391,34 @@ def scatter_value(
389
391
)
390
392
input_shape = input .shape
391
393
index_shape = index .shape
394
+ index_shape_list = list (index .shape )
395
+ if not (isinstance (index , TRTTensor )):
396
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
392
397
if len (input_shape ) != len (index_shape ):
393
398
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
394
- ranks = len (input_shape )
395
- dim = get_positive_dim (cast (int , dim ), ranks )
399
+ dim = get_positive_dim (dim , len (input_shape ))
396
400
dynamic_shape = has_dynamic_shape (input .shape )
397
401
if dynamic_shape :
398
402
# Check whether slice target dim is dynamic shape dim
399
403
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
400
404
401
- input_dims = len (input . shape )
405
+ input_dims = len (input_shape )
402
406
for i in range (0 , input_dims ):
403
- if index [i ] >= input .shape [i ]:
407
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
404
408
raise RuntimeError (
405
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
409
+ f"cannot have index size greater than the input size along dimension { dim } "
406
410
)
407
- value_tensor = value * torch .ones (index .shape )
411
+
412
+ value_tensor = get_trt_tensor (
413
+ ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
414
+ )
415
+ value_tensor = cast_trt_tensor (
416
+ ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
417
+ )
408
418
scatter_layer = ctx .net .add_scatter (
409
- input , index , value_tensor , trt .tensorrt . ScatterModekELEMENT
419
+ input , index , value_tensor , trt .ScatterMode . ELEMENT
410
420
)
411
- scatter_layer .set_axis ( dim )
421
+ scatter_layer .axis = dim
412
422
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
413
423
out = scatter_layer .get_output (0 )
414
424
return out
@@ -432,6 +442,8 @@ def scatter_src(
432
442
input_shape = input .shape
433
443
index_shape = index .shape
434
444
src_shape = src .shape
445
+ if not (isinstance (index , TRTTensor )):
446
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
435
447
if len (input_shape ) != len (index_shape ):
436
448
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
437
449
if len (index_shape ) != len (src_shape ):
@@ -445,14 +457,23 @@ def scatter_src(
445
457
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
446
458
447
459
for i in range (0 , input_dims ):
448
- if index [i ] >= input .shape [i ]:
460
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
449
461
raise RuntimeError (
450
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
462
+ f"cannot have index size greater than the input size along dimension { dim } "
451
463
)
464
+ input_dtype = input .dtype
465
+ # required for cases where src is a constant
466
+ src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
467
+ if input_dtype != src_dtype :
468
+ raise RuntimeError (f"The type of input and src should be made" )
469
+ src_tensor = src
470
+ if not (isinstance (src , TRTTensor )):
471
+ src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
472
+
452
473
scatter_layer = ctx .net .add_scatter (
453
- input , index , src , trt .tensorrt . ScatterModekELEMENT
474
+ input , index , src_tensor , trt .ScatterMode . ELEMENT
454
475
)
455
- scatter_layer .set_axis ( dim )
476
+ scatter_layer .axis = dim
456
477
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
457
478
out = scatter_layer .get_output (0 )
458
479
return out
0 commit comments