[Feature] Support more than 1 input for VotingClassifier #1016

Open
eddiebergman opened this issue Aug 18, 2023 · 1 comment
@eddiebergman

Hi there,

I am new to ONNX in general, so apologies if this issue is misplaced or I am missing something fundamental.

I'm coming from the tool autosklearn and planning to introduce some basic ONNX support by exporting the models it finds after optimizing over possible pipelines. These models will mostly be an ensemble (VotingClassifier) whose members are themselves Pipelines with disjoint imputation strategies, feature preprocessing and estimators.

Based on the error below, it seems that using a VotingClassifier requires all features to be numeric (or at least of the same TensorType). Is this correct? Is there something fundamental that would prevent the SklearnVotingClassifier operator from working with more than 1 input?

I am linking to this issue here in case anyone using autosklearn would like to enable ONNX support and is able to contribute! I've included a reproducible example and the traceback below.

Reproducible Example

Apologies for using OpenML; the sklearn toy datasets do not have such varied column types.

from __future__ import annotations


def main():
    import openml
    from mlprodict.onnx_conv import guess_schema_from_data
    from onnxruntime import InferenceSession
    from skl2onnx import to_onnx
    from sklearn.ensemble import RandomForestClassifier, VotingClassifier
    from sklearn.impute import SimpleImputer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OrdinalEncoder

    dataset = openml.datasets.get_dataset(31)
    X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)

    model = VotingClassifier(
        estimators=[
            (
                "est",
                Pipeline(
                    steps=[
                        ("imputer", SimpleImputer(strategy="most_frequent")),
                        ("encoder", OrdinalEncoder()),
                        ("rf", RandomForestClassifier(n_estimators=10)),
                    ],
                ),
            ),
        ],
    )
    model.fit(X, y)
    schema = guess_schema_from_data(X)

    # Errors here
    onnx_model = to_onnx(model=model, initial_types=schema)

    sess = InferenceSession(onnx_model.SerializeToString())
    inputs = {c: X[c].to_numpy().reshape((-1, 1)) for c in X.columns}
    got = sess.run(None, inputs)

    print(got)


if __name__ == "__main__":
    main()
Traceback
/blank/.venv/lib/python3.10/site-packages/openml/datasets/functions.py:438: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.
  warnings.warn(
Traceback (most recent call last):
  File "/blank/onnx-test.py", line 45, in <module>
    main()
  File "/blank/onnx-test.py", line 35, in main
    onnx_model = to_onnx(model=model, initial_types=schema)
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/convert.py", line 306, in to_onnx
    return convert_sklearn(
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/convert.py", line 208, in convert_sklearn
    onnx_model = convert_topology(
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/common/_topology.py", line 1532, in convert_topology
    topology.convert_operators(container=container, verbose=verbose)
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/common/_topology.py", line 1348, in convert_operators
    self.call_shape_calculator(operator)
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/common/_topology.py", line 1163, in call_shape_calculator
    operator.infer_types()
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/common/_topology.py", line 652, in infer_types
    shape_calc(self)
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/shape_calculators/voting_classifier.py", line 8, in voting_classifier_shape_calculator
    return _calculate_linear_classifier_output_shapes(
  File "/blank/.venv/lib/python3.10/site-packages/skl2onnx/common/shape_calculator.py", line 43, in _calculate_linear_classifier_output_shapes
    check_input_and_output_numbers(
  File "/blank/.venv/lib/python3.10/site-packages/onnxconverter_common/utils.py", line 295, in check_input_and_output_numbers
    raise RuntimeError(
RuntimeError: For operator SklearnVotingClassifier (type: SklearnVotingClassifier), at most 1 input(s) is(are) supported but we got 20 input(s) which are
['checking_status', 'duration', 'credit_history', 'purpose', 'credit_amount', 'savings_status', 'employment', 'installment_commitment',
'personal_status', 'other_parties', 'residence_since', 'property_magnitude', 'age', 'other_payment_plans', 'housing',
'existing_credits', 'job', 'num_dependents', 'own_telephone', 'foreign_worker']
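
The error is raised by a check on how many inputs reach the operator: a schema with one entry per column produces 20 inputs, while the voting classifier converter accepts one. A simplified stand-in for that kind of check is sketched below. `check_input_count` and the input lists are hypothetical names for illustration, not the code in `onnxconverter_common`.

```python
def check_input_count(op_name: str, inputs: list[str], max_inputs: int = 1) -> None:
    """Raise if an operator receives more inputs than it supports."""
    if len(inputs) > max_inputs:
        raise RuntimeError(
            f"For operator {op_name}, at most {max_inputs} input(s) "
            f"is(are) supported but we got {len(inputs)} input(s)"
        )


# One input per DataFrame column, as a per-column schema would declare.
per_column = ["checking_status", "duration", "credit_history"]
# One input holding the whole table.
single = ["X"]

# A single input passes the check silently.
check_input_count("SklearnVotingClassifier", single)

# Several inputs trigger the same kind of RuntimeError as in the traceback.
try:
    check_input_count("SklearnVotingClassifier", per_column)
    message = ""
except RuntimeError as exc:
    message = str(exc)

print(message)
```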
@xadupre
Collaborator

xadupre commented Oct 3, 2023

The converter expects a single tensor as input. You can use a ColumnTransformer to concatenate all columns into a single one. I also put the encoder first, because ONNX only supports numerical values for Imputer. This is the pipeline validated in PR #1030.

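from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder

# X is the feature table loaded in the reproducible example.
model = Pipeline(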
    steps=[
        (
            "concat",
            ColumnTransformer(
                [("concat", "passthrough", list(range(X.shape[1])))],
                sparse_threshold=0,
            ),
        ),
        (
            "voting",
            VotingClassifier(
                flatten_transform=False,
                estimators=[
                    (
                        "est",
                        Pipeline(
                            steps=[
                                # This encoder is placed before SimpleImputer because
                                # onnx does not support text for Imputer
                                ("encoder", OrdinalEncoder()),
                                (
                                    "imputer",
                                    SimpleImputer(strategy="most_frequent"),
                                ),
                                (
                                    "rf",
                                    RandomForestClassifier(
                                        n_estimators=4,
                                        max_depth=4,
                                        random_state=0,
                                    ),
                                ),
                            ],
                        ),
                    ),
                ],
            ),
        ),
    ]
)
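
A minimal sketch of the same pipeline shape on a small invented table, assuming scikit-learn, pandas and NumPy are installed. The data, the column names and the `make_concat` helper are made up for illustration. It only shows that the `concat` step hands the voting classifier one 2-D array, and it stops before the ONNX conversion.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder

# Invented toy data: one categorical and one numeric column.
X = pd.DataFrame(
    {
        "housing": ["own", "rent", "own", "free", "rent", "own"],
        "duration": [6, 48, 12, 42, 24, 36],
    }
)
y = np.array([0, 1, 0, 1, 1, 0])


def make_concat(n_columns: int) -> ColumnTransformer:
    """Pass every column through unchanged, stacked into one dense array."""
    return ColumnTransformer(
        [("concat", "passthrough", list(range(n_columns)))],
        sparse_threshold=0,
    )


# On its own, the concat step turns the DataFrame into a single 2-D array.
merged = make_concat(X.shape[1]).fit_transform(X)

model = Pipeline(
    steps=[
        ("concat", make_concat(X.shape[1])),
        (
            "voting",
            VotingClassifier(
                flatten_transform=False,
                estimators=[
                    (
                        "est",
                        Pipeline(
                            steps=[
                                # Encode first so the imputer only sees numbers.
                                ("encoder", OrdinalEncoder()),
                                ("imputer", SimpleImputer(strategy="most_frequent")),
                                ("rf", RandomForestClassifier(n_estimators=4, random_state=0)),
                            ],
                        ),
                    ),
                ],
            ),
        ),
    ]
)
model.fit(X, y)
predictions = model.predict(X)
print(merged.shape, predictions)
```

Because every column ends up in one array, the voting classifier sees a single input regardless of how many columns the original table has.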
