
[quant] Adding quantized mul kernel #24444


Closed
z-a-f wants to merge 9 commits

Conversation

z-a-f

@z-a-f commented Aug 15, 2019

Stack from ghstack:

Pull Request resolved: #24444

Differential Revision: D16844824

@pytorchbot added the module: nn (Related to torch.nn) label Aug 15, 2019
z-a-f pushed a commit that referenced this pull request Aug 15, 2019
ghstack-source-id: c2656e9
Pull Request resolved: #24444
Contributor

@jerryzh168 left a comment


LGTM

auto iter = TensorIterator::binary_op(out, self, other);
AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul", [&]() {
  cpu_kernel(iter, [&](scalar_t a, scalar_t b) -> scalar_t {
    const auto da = at::dequantize_val(self_scale, self_zero_point, a);
Contributor


Don't we have a requantize_val API?

Author


I was worried about the error it would contribute. I will try it out!

Author


Actually, I don't think that will help at all. The problem is that multiplication doesn't really work well with quantized types. My rationale: if we requantize the values, we still need to compute qc = round(a*b/s + z), where a and b are the floating-point multiplicands and s, z are the output quantization parameters. What we are given are the requantized values qa = round(a/s + z) and qb = round(b/s + z). Given only qa and qb, I don't see a way to compute qc without dequantizing them first.
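To make the rationale concrete, here is a minimal sketch of the dequantize, multiply, quantize-to-output-params path the kernel above takes. The scales and zero points are made-up illustration values, and clamping to the quantized dtype's range is omitted:

import numpy as np

def quantize(x, scale, zero_point):
    # affine quantization: q = round(x / scale) + zero_point
    return np.round(x / scale) + zero_point

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

a, b = 0.37, -1.25   # floating-point multiplicands
sa, za = 0.05, 10    # input A quantization params (illustrative)
sb, zb = 0.04, 0     # input B quantization params (illustrative)
sc, zc = 0.10, 5     # output quantization params (illustrative)

qa = quantize(a, sa, za)
qb = quantize(b, sb, zb)

# What the kernel does: recover approximate float values, multiply in
# float, then quantize the product with the output parameters.
da = dequantize(qa, sa, za)
db = dequantize(qb, sb, zb)
qc = quantize(da * db, sc, zc)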

Contributor


OK. Also, I noticed you are using const here; Tensor is a pointer type, so the const is not necessary.

Author


I think this is a scalar_t value, not a Tensor.

"Only per tensor quantization is suuported in Mul.");
TORCH_CHECK(qa.qscheme() == qb.qscheme(),
"Both inputs to Mul must have the same quantization shceme.");
TORCH_CHECK(qa.numel() == qb.numel(),
Contributor


Why do we not support broadcast? For float, this works:

z = torch.rand(3, 4, 5)
v = torch.rand(3, 1, 1)
q = torch.mul(z, v)

and we should aim to match that behavior.

Author


This works because broadcasting logic exists for the float ops. Implementing broadcasting for the quantized ops is on the TODO list, but it is not part of this PR.
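For reference, a sketch of the per-tensor case this PR does cover (inputs with matching shapes). The Python API names used here, torch.quantize_per_tensor and torch.ops.quantized.mul, are the current names and may differ from what was available at the time of this PR; the scales and zero points are arbitrary illustration values:

import torch

x = torch.rand(3, 4, 5)
y = torch.rand(3, 4, 5)
qx = torch.quantize_per_tensor(x, 0.05, 0, torch.quint8)
qy = torch.quantize_per_tensor(y, 0.05, 0, torch.quint8)
# the quantized mul op takes the output scale and zero point explicitly
qz = torch.ops.quantized.mul(qx, qy, 0.1, 0)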

@z-a-f changed the title from "Adding quantized mul kernel" to "[quant] Adding quantized mul kernel" Aug 21, 2019
z-a-f pushed a commit that referenced this pull request Aug 21, 2019
@zou3519 deleted the gh/zafartahirov/33/head branch August 22, 2019 23:54
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 23, 2019
Summary: Pull Request resolved: pytorch/pytorch#24444

Test Plan: Imported from OSS

Differential Revision: D16844824

Pulled By: zafartahirov

fbshipit-source-id: 626c40e1cad8329c3d8517156f2d36d7a7472890
@facebook-github-bot
Contributor

@zafartahirov merged this pull request in 9764c2e.

Labels: Merged, module: nn (Related to torch.nn), oncall: quantization (Quantization support in PyTorch)