Initializer duplication method in QDQQuantizer ignores existing value_info tensor with same name #24705

Open
@skywall

Describe the issue

The _dup_initializer() method in QDQQuantizer checks only the initializer list when it generates the suffix for a new initializer name.

Consider an initializer named b that is being duplicated. The method creates a new initializer b0 without checking whether a value_info with the same name already exists. Both tensors are then added, and the resulting model is corrupted: the initializer is used instead of the value_info when the Q/DQ nodes are created.
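To illustrate the kind of check that appears to be missing, here is a minimal sketch of suffix selection that consults both name sets. The helper pick_duplicate_name is hypothetical and is not the ONNX Runtime implementation; it only assumes that the duplicate name is built as the base name plus a numeric suffix, as in the b -> b0 example above.

```python
def pick_duplicate_name(base, initializer_names, value_info_names):
    """Return base + numeric suffix that clashes with neither name set.

    Hypothetical helper for illustration; not part of onnxruntime.
    """
    taken = set(initializer_names) | set(value_info_names)
    suffix = 0
    # Increase the suffix until the candidate name is unused everywhere.
    while f"{base}{suffix}" in taken:
        suffix += 1
    return f"{base}{suffix}"


initializers = {"w", "b", "ww"}
value_infos = {"b0"}

# Looking at initializers only yields the clashing name from this report.
print(pick_duplicate_name("b", initializers, set()))        # b0
# Looking at value_info names as well skips the taken name.
print(pick_duplicate_name("b", initializers, value_infos))  # b1
```

In a real graph, the value_info set would presumably also need to include graph input and output names, since those share the same tensor namespace.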

A couple of conditions have to be met to reproduce this:

  1. The model must have a bias tensor shared by multiple consumers.
  2. There has to be a value_info with the same name as the newly generated initializer, so the value_info name must also end with a numeric suffix.
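The two conditions above can be turned into a quick pre-check. The sketch below is an assumption-laden illustration, not an ONNX Runtime API: find_risky_initializers is a hypothetical helper that works on plain name lists, and it assumes the clash happens when a value_info is named exactly like a shared initializer followed by digits.

```python
import re
from collections import Counter


def find_risky_initializers(node_inputs, initializer_names, value_info_names):
    """Map each shared initializer to the value_info names it may clash with.

    node_inputs: one list of input names per node.
    Hypothetical helper for illustration; not part of onnxruntime.
    """
    # Count how many nodes consume each name (once per node).
    uses = Counter(name for inputs in node_inputs for name in set(inputs))
    risky = {}
    for init in initializer_names:
        if uses[init] < 2:
            continue  # condition 1: must be shared by multiple consumers
        # Condition 2: a value_info named "<initializer><digits>".
        pattern = re.compile(re.escape(init) + r"\d+")
        clashes = sorted(v for v in value_info_names if pattern.fullmatch(v))
        if clashes:
            risky[init] = clashes
    return risky


# Names taken from the two-Conv graph in the reproduction script.
nodes = [["x", "w", "b"], ["b0", "ww", "b"]]
print(find_risky_initializers(nodes, ["w", "b", "ww"], ["b0"]))     # {'b': ['b0']}
print(find_risky_initializers(nodes, ["w", "b", "ww"], ["inter"]))  # {}
```

For an onnx ModelProto, the three arguments could be collected from graph.node (each node's input field), graph.initializer, and graph.value_info respectively.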

The xception model is affected by this problem.

To reproduce

Comment or uncomment the lines that set inter_tensor_name to switch between the valid and invalid cases:

import numpy as np
import onnx.helper
import onnxruntime
from onnx import TensorProto
from onnxruntime import quantization


def test_conv_():
    # inter_tensor_name = "inter" # <-- PASS
    inter_tensor_name = "b0"  # <-- CRASH

    kernel_shape = [2, 2]
    weight_shape = [10, 10] + kernel_shape
    bias_shape = [weight_shape[0]]
    weights_data_1 = np.random.random(np.prod(weight_shape)).reshape(weight_shape).astype(np.float32)
    bias_data = np.random.random(bias_shape).astype(np.float32)

    weight_shape_2 = [10, 10] + kernel_shape
    weights_data_2 = np.random.random(np.prod(weight_shape_2)).reshape(weight_shape_2).astype(np.float32)

    input_shape = [1, weight_shape[1], 5, 5]
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Conv", ["x", "w", "b"], [inter_tensor_name], kernel_shape=kernel_shape),
            onnx.helper.make_node("Conv", [inter_tensor_name, "ww", "b"], ["output"], kernel_shape=kernel_shape),
        ],
        "ConvTest",
        [onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)],
        [onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ())],
        [
            onnx.helper.make_tensor("w", TensorProto.FLOAT, weight_shape, weights_data_1),
            onnx.helper.make_tensor("b", TensorProto.FLOAT, bias_shape, bias_data),
            onnx.helper.make_tensor("ww", TensorProto.FLOAT, weight_shape_2, weights_data_2),
        ],
        value_info=[onnx.helper.make_tensor_value_info(inter_tensor_name, TensorProto.FLOAT, ())]
    )

    onnx_model = onnx.helper.make_model(graph)
    model_input = {"x": np.random.random(np.prod(input_shape)).reshape(input_shape).astype(np.float32)}
    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
    _ = sess.run(None, model_input)

    class RandomData:

        def __init__(self):
            self.cnt = 0

        def get_next(self) -> dict | None:
            if self.cnt == 3:
                return None

            self.cnt += 1
            return {
                "x": np.random.random(np.prod(input_shape)).reshape(input_shape).astype(np.float32)
            }

    quantization.quantize_static(onnx_model, "quantized.onnx", RandomData())

    quantized_model = onnx.load_model("quantized.onnx")
    sess = onnxruntime.InferenceSession(quantized_model.SerializeToString())
    _ = sess.run(None, model_input)

Urgency

No response

Platform

Linux

OS Version

WSL - Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Labels

quantization (issues related to quantization)
