@@ -6,6 +6,5 @@
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
 import torch.onnx.utils
-
 from functools import partial
 from functools import wraps
 
@@ -421,7 +420,7 @@ def cumsum(g, input, dim, dtype):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
         if dtype.node().kind() != "prim::Constant":
             return _unimplemented(name, "dtype")
-        return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
+        return g.at("cumsum", input, dim_i=dim)
     else:
         sym_help._onnx_opset_unsupported("cumsum", 9, 11)
 
@@ -431,7 +430,7 @@ def _sample_dirichlet(g, self, generator):
         if not sym_help._is_none(generator):
             return _unimplemented("_sample_dirichlet",
                                   "We are not able to export generator")
-        return g.op("ATen", self, operator_s="_sample_dirichlet")
+        return g.at("_sample_dirichlet", self)
     else:
         return sym_help._onnx_unsupported("_sample_dirichlet")
 
@@ -441,7 +440,7 @@ def _standard_gamma(g, self, generator):
         if not sym_help._is_none(generator):
             return _unimplemented("_standard_gamma",
                                   "We are not able to export generator")
-        return g.op("ATen", self, operator_s="_standard_gamma")
+        return g.at("_standard_gamma", self)
     else:
         return sym_help._onnx_unsupported("_standard_gamma")
 
@@ -508,11 +507,10 @@ def embedding_bag(g,
     if not sym_help._is_none(per_sample_weights):
         return sym_help._onnx_unsupported("embedding_bag with per_sample_weights")
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen",
+        return g.at("embedding_bag",
                     embedding_matrix,
                     indices,
                     offsets,
-                    operator_s="embedding_bag",
                     outputs=4,
                     scale_grad_by_freq_i=scale_grad_by_freq,
                     mode_i=mode,
@@ -549,7 +547,7 @@ def transpose(g, self, dim0, dim1):
         # if we don't have dim information we cannot
         # output a permute so use ATen instead
         if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-            return g.op("ATen", self, operator_s="transpose", dim0_i=dim0, dim1_i=dim1)
+            return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int")
         else:
             raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
                                "of unknown rank.")
@@ -1358,8 +1356,8 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
 @parse_args("v", "is", "v", "v", "f", "i")
 def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, weight, bias, normalized_shape_i=normalized_shape,
-                    eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
+        return g.at("layer_norm", input, weight, bias, normalized_shape_i=normalized_shape,
+                    eps_f=eps, cudnn_enable_i=cudnn_enable)
 
 
     axes = [-i for i in range(len(normalized_shape), 0, -1)]
@@ -1428,7 +1426,7 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s
 @parse_args("v", "i", "i", "i")
 def unfold(g, input, dimension, size, step):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
+        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
     sizes = sym_help._get_tensor_sizes(input)
     try:
         sizedim = sizes[dimension]
@@ -1477,7 +1475,7 @@ def index_put(g, self, indices_list_value, values, accumulate):
         indices_list = [indices_list_value]
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
         args = [self] + indices_list + [values, accumulate]
-        return g.op("ATen", *args, operator_s="index_put")
+        return g.at("index_put", *args)
 
 
     accumulate = sym_help._parse_arg(accumulate, "b")
@@ -1493,7 +1491,7 @@ def index_put(g, self, indices_list_value, values, accumulate):
 def index_fill(g, self, dim, index, value):
     dim_value = sym_help._parse_arg(dim, "i")
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")
+        return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
     expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
     value = sym_help._maybe_get_scalar(value)
     value = sym_help._if_scalar_type_as(g, value, self)
@@ -1505,7 +1503,7 @@ def index_fill(g, self, dim, index, value):
 def index_copy(g, self, dim, index, source):
     dim_value = sym_help._parse_arg(dim, "i")
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy")
+        return g.at("index_copy", self, index, source, dim_i=dim_value)
     expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
     return scatter(g, self, dim, expanded_index, source)
 
@@ -1520,7 +1518,7 @@ def type_as(g, self, other):
     else:
         if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
             # We don't know the type of other, bail by emitting ATen
-            return g.op("ATen", self, other, operator_s="type_as")
+            return g.at("type_as", self, other)
         else:
             raise RuntimeError("Unsupported: ONNX export of type_as for tensor "
                                "of unknown dtype. Please check if the dtype of the "
@@ -1530,7 +1528,7 @@ def type_as(g, self, other):
 @parse_args("v", "v", "i", "f")
 def cosine_similarity(g, x1, x2, dim, eps):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", x1, x2, dim_i=dim, eps_f=eps, operator_s="cosine_similarity")
+        return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
     else:
         return sym_help._onnx_unsupported("cosine_similarity")
 
@@ -1687,7 +1685,7 @@ def norm(g, self, p, dim, keepdim):
 @parse_args("v", "v", "v", "i")
 def conv_tbc(g, input, weight, bias, pad):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
+        return g.at("conv_tbc", input, weight, bias, pad_i=pad)
     else:
         # input must have 3 dimensions, see:
         # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
@@ -1703,7 +1701,7 @@ def conv_tbc(g, input, weight, bias, pad):
 @parse_args("v", "i", "i")
 def _unique(g, input, sorted, return_inverse):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
+        return g.at("_unique", input, sorted_i=sorted,
                     return_inverse_i=return_inverse, outputs=2)
     else:
         return sym_help._onnx_unsupported("_unique")
@@ -1712,7 +1710,7 @@ def _unique(g, input, sorted, return_inverse):
 @parse_args("v", "i", "i", "i")
 def _unique2(g, input, sorted, return_inverse, return_counts):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, operator_s="_unique2", sorted_i=sorted,
+        return g.at("_unique2", input, sorted_i=sorted,
                     return_inverse_i=return_inverse, return_counts_i=return_counts,
                     outputs=3)
     else:
@@ -2725,7 +2723,7 @@ def logsumexp(g, input, dim, keepdim):
 
 def arange(g, *args):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", *args, operator_s="arange")
+        return g.at("arange", *args)
 
     def _get_arange_dtype(dtype):
         dtype = sym_help._maybe_get_const(dtype, "i")
@@ -2788,7 +2786,7 @@ def masked_fill(g, self, mask, value):
 
 def index(g, self, index):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", self, index, operator_s="index")
+        return g.at("index", self, index, overload_name="Tensor")
 
     if sym_help._is_packed_list(index):
         indices = sym_help._unpack_list(index)
@@ -2963,8 +2961,8 @@ def gelu(g, self):
 @parse_args("v", "i", "v", "v", "f", "i")
 def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
     if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
-                    eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
+        return g.at("group_norm", input, weight, bias, num_groups_i=num_groups,
+                    eps_f=eps, cudnn_enabled_i=cudnn_enabled)
 
     channel_size = sym_help._get_tensor_dim_size(input, 1)
     if channel_size is not None:
@@ -3021,7 +3019,7 @@ def _weight_norm(g, weight_v, weight_g, dim):
         div = g.op("Div", weight_v, norm_v)
         return g.op("Mul", div, weight_g)
     elif sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op("ATen", weight_v, weight_g, dim_i=dim, operator_s="_weight_norm")
+        return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
     else:
         raise RuntimeError("Unsupported: ONNX export of _weight_norm for tensor "
                            "of unknown rank.")