
Commit 7adb421

misc fixes for issues found in ort integration (#4681) (#4695)

Authored by p-wysocki and liqunfu

Signed-off-by: p-wysocki <[email protected]>
Co-authored-by: liqun Fu <[email protected]>
1 parent a913015 commit 7adb421

4 files changed (+48, -16 lines)

onnx/defs/schema.h

Lines changed: 4 additions & 0 deletions
@@ -1033,6 +1033,10 @@ class OpSchemaRegistry final : public ISchemaRegistry {
       auto& op_name = op_schema.Name();
       auto& op_domain = op_schema.domain();
       auto ver = op_schema.SinceVersion();
+      if (OpSchema::kUninitializedSinceVersion == ver) {
+        op_schema.SinceVersion(1);
+        ver = op_schema.SinceVersion();
+      }
       // Stops because the opset_version is higher than opset_version_to_load
       if (opset_version_to_load != 0 && ver > opset_version_to_load) {
         return;
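The new guard means a schema registered without an explicit SinceVersion now defaults to version 1 instead of staying uninitialized. Purely as a sketch (not part of this commit), a schema's resolved since_version can be checked from Python; the op name and domain below are arbitrary examples:

import onnx.defs

# Inspect the since_version the registry resolved for a schema.
# "ArrayFeatureExtractor" and "ai.onnx.ml" are placeholder values for illustration.
schema = onnx.defs.get_schema("ArrayFeatureExtractor", domain="ai.onnx.ml")
print(schema.since_version)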

onnx/defs/traditionalml/defs.cc

Lines changed: 41 additions & 14 deletions
@@ -26,23 +26,44 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
         }
         const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
         const auto input_ndim = input_shape.dim_size();
-
+        if (input_ndim == 1) {
+          return;
+        }
         auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
         // This operator only applies to the last dimension; thus -1
         for (int i = 0; i < input_ndim - 1; ++i) {
           *output_shape->add_dim() = input_shape.dim(i);
         }
-        // The length of second input is the length of the last dimension of the output
+
+        // The value of the output's last dimension is the total number of indices.
+        // Leave the last dimension unknown if it cannot be calculated.
+        auto last_dim = output_shape->add_dim();
         if (hasInputShape(ctx, 1)) {
           const auto& indices_shape = getInputShape(ctx, 1);
           if (indices_shape.dim_size() > 0) {
-            auto dim = indices_shape.dim(0);
-            *output_shape->add_dim() = dim;
-            return;
+            int64_t num_indices = 1;
+            std::string single_symbolic_dim;
+            for (int i = 0; i < indices_shape.dim_size(); i++) {
+              if (indices_shape.dim(i).has_dim_value()) {
+                num_indices *= indices_shape.dim(i).dim_value();
+              } else if (indices_shape.dim(i).has_dim_param()) {
+                if (single_symbolic_dim.empty()) {
+                  // A symbolic dimension param can only be used if all the other dim values are 1.
+                  single_symbolic_dim = indices_shape.dim(i).dim_param();
+                } else {
+                  return;
+                }
+              } else {
+                return;
+              }
+            }
+            if (single_symbolic_dim.empty()) {
+              last_dim->set_dim_value(num_indices);
+            } else if (num_indices == 1) {
+              last_dim->set_dim_param(single_symbolic_dim);
+            }
           }
         }
-        // Unknown length of the last dimension
-        output_shape->add_dim();
       })
       .TypeConstraint(
           "T",
@@ -851,9 +872,9 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
               "Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
         }

-        std::vector<std::string> label_strs;
-        auto result = getRepeatedAttribute(ctx, "classlabels_strings", label_strs);
-        bool using_strings = (result && !label_strs.empty());
+        std::vector<std::string> classlabels_strings;
+        auto result = getRepeatedAttribute(ctx, "classlabels_strings", classlabels_strings);
+        bool using_strings = (result && !classlabels_strings.empty());
         if (using_strings) {
           updateOutputElemType(ctx, 0, TensorProto::STRING);
         } else {
@@ -864,10 +885,16 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
         checkInputRank(ctx, 0, 2);
         Dim N, E;
         unifyInputDim(ctx, 0, 0, N);
-        std::vector<int64_t> class_ids;
-        auto has_ids = getRepeatedAttribute(ctx, "class_ids", class_ids);
-        if (has_ids) {
-          unifyDim(E, class_ids.size());
+
+        if (using_strings) {
+          unifyDim(E, classlabels_strings.size());
+        } else {
+          std::vector<int64_t> classlabels_int64s;
+          result = getRepeatedAttribute(ctx, "classlabels_int64s", classlabels_int64s);
+          if (!result || classlabels_int64s.empty()) {
+            fail_shape_inference("Neither classlabels_int64s nor classlabels_strings is set.");
+          }
+          unifyDim(E, classlabels_int64s.size());
         }
         updateOutputShape(ctx, 0, {N});
         updateOutputShape(ctx, 1, {N, E});
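With this change the class dimension E of the second output ([N, E]) is taken from whichever classlabels attribute is present, rather than from the removed class_ids lookup. Roughly, as an illustrative Python sketch (not the ONNX API):

from typing import List, Optional

def class_dim(classlabels_strings: Optional[List[str]], classlabels_int64s: Optional[List[int]]) -> int:
    """E for output 1 of shape [N, E], mirroring the shape inference above."""
    if classlabels_strings:
        return len(classlabels_strings)
    if classlabels_int64s:
        return len(classlabels_int64s)
    raise ValueError("Neither classlabels_int64s nor classlabels_strings is set.")

print(class_dim(None, [0, 1, 2, 3, 4]))  # 5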

onnx/test/shape_inference_test.py

Lines changed: 1 addition & 1 deletion
@@ -8567,7 +8567,7 @@ def test_tree_ensemble_classifier(self) -> None:
             "TreeEnsembleClassifier",
             ["x"],
             ["y", "z"],
-            class_ids=[0, 1, 2, 3, 4],
+            classlabels_int64s=[0, 1, 2, 3, 4],
             domain=ONNX_ML_DOMAIN,
         )
         graph = self._make_graph(
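For context, the updated test builds a TreeEnsembleClassifier node with only classlabels_int64s set. A standalone sketch along the same lines might look as follows; the attribute values, opset versions, and input rank are placeholders, and the test's _make_graph helper is replaced with plain onnx.helper calls:

import onnx
from onnx import TensorProto, helper, shape_inference

node = helper.make_node(
    "TreeEnsembleClassifier",
    ["x"],
    ["y", "z"],
    domain="ai.onnx.ml",
    classlabels_int64s=[0, 1, 2, 3, 4],  # placeholder labels, as in the test
)
graph = helper.make_graph(
    [node],
    "tree_ensemble_classifier_shape",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, 10])],
    [
        helper.make_tensor_value_info("y", TensorProto.INT64, None),
        helper.make_tensor_value_info("z", TensorProto.FLOAT, None),
    ],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("ai.onnx.ml", 3), helper.make_opsetid("", 17)],
)
inferred = shape_inference.infer_shapes(model)
# The second dim of "z" should now be inferred as 5, the number of class labels.
print(inferred.graph.output)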

setup.cfg

Lines changed: 2 additions & 1 deletion
@@ -31,8 +31,9 @@ max-line-length = 88
 # type comments.
 # E203 is needed to support black formatting.
 # https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html
+# B905 `zip()` without an explicit `strict=` parameter.
 # B950 as we have too many lines too long. This is ok because black handles most cases.
-ignore = E127, E128, E265, E266, E402, E501, E722, F405, P207, P208, W503, F401, E203, B950
+ignore = E127, E128, E265, E266, E402, E501, E722, F405, P207, P208, W503, F401, E203, B905, B950
 exclude =
     .git,
     __pycache__,
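B905 is the flake8-bugbear check for zip() calls that omit strict= (available since Python 3.10); ignoring it keeps existing call sites unchanged. A small illustration of the pattern the check targets:

names = ["a", "b", "c"]
values = [1, 2]

# Without strict=, zip() silently stops at the shorter iterable; this is what B905 flags.
print(list(zip(names, values)))                 # [('a', 1), ('b', 2)]

# With strict=True (Python 3.10+), the length mismatch raises ValueError instead.
# print(list(zip(names, values, strict=True)))  # ValueError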
