@@ -26,23 +26,44 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
26
26
}
27
27
const auto & input_shape = ctx.getInputType (0 )->tensor_type ().shape ();
28
28
const auto input_ndim = input_shape.dim_size ();
29
-
29
+ if (input_ndim == 1 ) {
30
+ return ;
31
+ }
30
32
auto output_shape = ctx.getOutputType (0 )->mutable_tensor_type ()->mutable_shape ();
31
33
// This operator only applies to the last dimension; thus -1
32
34
for (int i = 0 ; i < input_ndim - 1 ; ++i) {
33
35
*output_shape->add_dim () = input_shape.dim (i);
34
36
}
35
- // The length of second input is the length of the last dimension of the output
37
+
38
+ // value of the output's last dimension is the total amount of indices
39
+ // set Unknown length for the last dimension if it cannot be calculated
40
+ auto last_dim = output_shape->add_dim ();
36
41
if (hasInputShape (ctx, 1 )) {
37
42
const auto & indices_shape = getInputShape (ctx, 1 );
38
43
if (indices_shape.dim_size () > 0 ) {
39
- auto dim = indices_shape.dim (0 );
40
- *output_shape->add_dim () = dim;
41
- return ;
44
+ int64_t num_indices = 1 ;
45
+ std::string single_symbolic_dim;
46
+ for (int i = 0 ; i < indices_shape.dim_size (); i++) {
47
+ if (indices_shape.dim (i).has_dim_value ()) {
48
+ num_indices *= indices_shape.dim (i).dim_value ();
49
+ } else if (indices_shape.dim (i).has_dim_param ()) {
50
+ if (single_symbolic_dim.empty ()) {
51
+ // it is possible to set symbolic dimension param if the rest dim values are all value 1
52
+ single_symbolic_dim = indices_shape.dim (i).dim_param ();
53
+ } else {
54
+ return ;
55
+ }
56
+ } else {
57
+ return ;
58
+ }
59
+ }
60
+ if (single_symbolic_dim.empty ()) {
61
+ last_dim->set_dim_value (num_indices);
62
+ } else if (num_indices == 1 ) {
63
+ last_dim->set_dim_param (single_symbolic_dim);
64
+ }
42
65
}
43
66
}
44
- // Unknown length of the last dimension
45
- output_shape->add_dim ();
46
67
})
47
68
.TypeConstraint(
48
69
" T" ,
@@ -851,9 +872,9 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
851
872
" Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified." );
852
873
}
853
874
854
- std::vector<std::string> label_strs ;
855
- auto result = getRepeatedAttribute (ctx, " classlabels_strings" , label_strs );
856
- bool using_strings = (result && !label_strs .empty ());
875
+ std::vector<std::string> classlabels_strings ;
876
+ auto result = getRepeatedAttribute (ctx, " classlabels_strings" , classlabels_strings );
877
+ bool using_strings = (result && !classlabels_strings .empty ());
857
878
if (using_strings) {
858
879
updateOutputElemType (ctx, 0 , TensorProto::STRING);
859
880
} else {
@@ -864,10 +885,16 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
864
885
checkInputRank (ctx, 0 , 2 );
865
886
Dim N, E;
866
887
unifyInputDim (ctx, 0 , 0 , N);
867
- std::vector<int64_t > class_ids;
868
- auto has_ids = getRepeatedAttribute (ctx, " class_ids" , class_ids);
869
- if (has_ids) {
870
- unifyDim (E, class_ids.size ());
888
+
889
+ if (using_strings) {
890
+ unifyDim (E, classlabels_strings.size ());
891
+ } else {
892
+ std::vector<int64_t > classlabels_int64s;
893
+ result = getRepeatedAttribute (ctx, " classlabels_int64s" , classlabels_int64s);
894
+ if (!result || classlabels_int64s.empty ()) {
895
+ fail_shape_inference (" Non of classlabels_int64s or classlabels_strings is set." );
896
+ }
897
+ unifyDim (E, classlabels_int64s.size ());
871
898
}
872
899
updateOutputShape (ctx, 0 , {N});
873
900
updateOutputShape (ctx, 1 , {N, E});
0 commit comments