-
Notifications
You must be signed in to change notification settings - Fork 24.4k
Introduce AT_FORALL_QINTS #22931
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce AT_FORALL_QINTS #22931
Conversation
Differential Revision: [D16467985](https://our.internmc.facebook.com/intern/diff/D16467985)
@@ -1800,7 +1800,8 @@ inline bool is_quantized(Tensor self) { | |||
return static_cast<T*>(this->data_ptr()); \ | |||
} | |||
|
|||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST) | |||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_CAST) | |||
AT_FORALL_QINTS(DEFINE_CAST) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to cast them to quint8* or uint8*?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel uint8*
might be more useful for people
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ya, I think that's the correct solution but outside of the scope of this PR, which is just trying to rationalize the AT_FORALL macros. So I'm just trying to match up with this one:
pytorch/aten/src/ATen/Dispatch.h
Lines 256 to 269 in 34f5356
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ | |
[&] { \ | |
const auto& SCALAR_TYPE C10_UNUSED = TYPE; \ | |
switch (TYPE) { \ | |
AT_QINT_PRIVATE_CASE_TYPE( \ | |
at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ | |
AT_QINT_PRIVATE_CASE_TYPE( \ | |
at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \ | |
AT_QINT_PRIVATE_CASE_TYPE( \ | |
at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \ | |
default: \ | |
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ | |
} \ | |
}() |
so I should have actually called this AT_FORALL_QINT_TYPES, but I fix that up in a later PR in this stack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I filed #23440 for addressing that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Summary: Pull Request resolved: pytorch/pytorch#22931 Test Plan: Imported from OSS Differential Revision: D16467985 Pulled By: gchanan fbshipit-source-id: 3925fc96a641e66b92fa65c542a2a23190c915a5
Stack from ghstack:
Differential Revision: D16467985