[quant] Adding quantized mul kernel #24444
Conversation
Differential Revision: [D16844824](https://our.internmc.facebook.com/intern/diff/D16844824)
LGTM
```cpp
auto iter = TensorIterator::binary_op(out, self, other);
// Dispatches over the quantized dtypes (qint8, quint8, qint32).
AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul", [&]() {
  cpu_kernel(iter, [&](scalar_t a, scalar_t b) -> scalar_t {
    // Dequantize each element before multiplying in float.
    const auto da = at::dequantize_val(self_scale, self_zero_point, a);
```
Don't we have a `requantize_val` API?
I was worried about the error it would contribute. I will try it out!
Actually, I don't think that will help at all. The problem is that multiplication doesn't really work well with quantized types. My rationale: even if we requantize the values, we still need to compute `qc = round(a*b/s + z)`, where `a, b` are the floating-point multiplicands and `s, z` are the output quantization params. What we are given are the requantized values `qa = round(a/s + z)` and `qb = round(b/s + z)`, and from `qa, qb` we need to compute `qc`. I don't see a way to do it without dequantizing them.
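For concreteness, here is a minimal Python sketch of the per-element math being described: dequantize both inputs, multiply in float, requantize with the output params. The function name, signature, and use of Python's built-in `round` are illustrative assumptions, not the kernel's actual API.

```python
# Sketch of dequantize -> multiply -> requantize, assuming per-tensor
# affine quantization. Names here are illustrative, not PyTorch API.
def quantized_mul_elem(qa, qb, a_scale, a_zp, b_scale, b_zp, out_scale, out_zp):
    a = a_scale * (qa - a_zp)            # dequantize input a
    b = b_scale * (qb - b_zp)            # dequantize input b
    return round(a * b / out_scale) + out_zp  # requantize the float product
```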
OK, I also noticed you are using `const` here; `Tensor` is a pointer under the hood, so the `const` is not necessary.
I think this is a `scalar_t` value, not a tensor.
"Only per tensor quantization is suuported in Mul."); | ||
TORCH_CHECK(qa.qscheme() == qb.qscheme(), | ||
"Both inputs to Mul must have the same quantization shceme."); | ||
TORCH_CHECK(qa.numel() == qb.numel(), |
Why do we not support broadcast? For float, this works:

```python
import torch

z = torch.rand(3, 4, 5)
v = torch.rand(3, 1, 1)
q = torch.mul(z, v)  # v broadcasts against z
```

and we should aim to match the same behavior.
That works because the float ops have broadcasting logic behind them. Implementing broadcasting for the quantized ops is on the TODO list, but it is not part of this PR. In the meantime one can route through float, as sketched below.
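A hedged workaround sketch until broadcasting lands: dequantize, let the float kernel broadcast, then requantize. `torch.quantize_per_tensor` and the chosen `scale`/`zero_point` values are illustrative assumptions (the quantization API at the time of this PR differed), not part of this PR.

```python
import torch

# Workaround sketch: go through float to get broadcasting, then requantize.
# torch.quantize_per_tensor and the scale/zero_point below are assumptions
# for illustration only.
scale, zp = 0.1, 0
z = torch.quantize_per_tensor(torch.rand(3, 4, 5), scale, zp, torch.quint8)
v = torch.quantize_per_tensor(torch.rand(3, 1, 1), scale, zp, torch.quint8)
q = torch.quantize_per_tensor(
    torch.mul(z.dequantize(), v.dequantize()),  # float mul broadcasts
    scale, zp, torch.quint8)
```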
Summary: Pull Request resolved: pytorch/pytorch#24444
Test Plan: Imported from OSS
Differential Revision: D16844824
Pulled By: zafartahirov
fbshipit-source-id: 626c40e1cad8329c3d8517156f2d36d7a7472890
@zafartahirov merged this pull request in 9764c2e.
Stack from ghstack:
Pull Request resolved: #24444
Differential Revision: D16844824