Rework vectorized load/stores. #1457
Conversation
LGTM. Added some questions and comments.
loadGeneric<scalar_t, vec_size>(to, from);
break;
case 8: {
  uint2 const data = *reinterpret_cast<uint2 const*>(from);
Just out of curiosity, is this copy necessary for using the inline assembly?
I think we might be able to do const&, but the reinterpret_cast is necessary.
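For context, a minimal sketch of the pattern under discussion, assuming a 64-bit (uint2) store path; the function name is illustrative, not the PR's actual code. It shows the const& variant suggested above feeding an inline-PTX vectorized store:

```cpp
// Sketch: bind a const& to the source instead of making a copy; the
// reinterpret_cast is still needed to hand the inline PTX 32-bit operands.
__device__ void storeGlobalVec2(void* to, void const* from) {
  uint2 const& data = *reinterpret_cast<uint2 const*>(from);
  asm volatile(
      "st.global.v2.u32 [%0], {%1, %2};" ::"l"(to), "r"(data.x), "r"(data.y)
      : "memory");
}
```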
auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
if (uop->in()->isScalar()) {
  if (out_tv->getMemoryType() == MemoryType::Local) {
    // Vectorized intiialization
Typo: initialization
}

template <typename scalar_t, int vec_size>
__device__ void loadLocalToGlobal(scalar_t* to, scalar_t* from) {
Do we also need to have variations for shared mem? Or does the compiler properly take care of those cases?
I think we'll probably want shared memory versions, maybe ldg/sts versions; it's open to expansion for sure. I don't have a case where smem doesn't do the right thing, but we're really not using that path much at the moment.
I have seen the compiler auto-vectorize smem accesses but haven't yet seen it de-vectorize any, so I guess loadGeneric should be good for a while. Also, a smem pointer would require a cvta instruction and an extra register for its output, which I thought might affect register allocation if we use asm to do it.
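For reference, a minimal sketch of the technique this function is built around, shown for the 16-byte case only (an illustration under that assumption, not the PR's full implementation):

```cpp
// Sketch: move vec_size elements from a local (register) buffer to global
// memory with a single vectorized PTX store. Covers only the 16-byte case.
template <typename scalar_t, int vec_size>
__device__ void loadLocalToGlobal(scalar_t* to, scalar_t* from) {
  static_assert(
      sizeof(scalar_t) * vec_size == 16, "sketch covers the 16B case only");
  uint4 const& data = *reinterpret_cast<uint4 const*>(from);
  asm volatile(
      "st.global.v4.u32 [%0], {%1, %2, %3, %4};" ::"l"(to),
      "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)
      : "memory");
}
```

Writing the store as asm guarantees a vectorized st.global in the SASS, rather than relying on the compiler to keep the access vectorized.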
if (alias_tv->getMemoryType() == MemoryType::Local &&
    va.find(alias_tv) != va.end()) {
  indent() << "auto& " << varName(tv) << " = " << varName(alias_tv)
           << ";\n";
} else {
  indent() << buffer_dtype << "* " << varName(tv) << " = "
           << varName(alias_tv) << ";\n";
}
Some comment describing why the local and vectorized case needs the special handling would be great. I assume we could always generate auto&. Would there be any preference for the pointer style of code?
I don't think there's any explicit benefit; it helps me look through the code quickly just because I know the difference, but agreed, we could just do auto&.
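To illustrate the two codegen styles being discussed (hypothetical generated output; the tensor names are made up):

```cpp
// Reference style, emitted for local, vectorized aliases:
auto& T3 = T2;
// Pointer style, emitted for everything else:
float* T3 = T2;
```

The reference style presumably preserves the aliased buffer's declared (e.g. Array) type, so vectorized accesses through the alias stay well-formed, whereas the pointer style decays to a plain element pointer.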
Overall looks good to me. Some minor discussion points.
// Used for vectorized allocations that are not in registers
template <typename scalar_t, int vec_size>
void arraySet(scalar_t* buff, scalar_t val) {
Would we need specializations for vectorized initialization? We could also rely on the compiler's auto-vectorization pass.
I didn't explicitly check, as it's unlikely to be a perf bottleneck for memory-bound ops. If it's necessary, we could definitely do it.
Sure. I guess it would matter if we have bank conflicts in initialization or are limited by the instruction cache. I'll add them if I see anything limited by that.
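As a sketch of what such a specialization could look like next to the generic loop (the float4 variant is hypothetical, not part of this PR):

```cpp
// Generic element-wise fill; the compiler may auto-vectorize this loop.
template <typename scalar_t, int vec_size>
__device__ void arraySet(scalar_t* buff, scalar_t val) {
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    buff[i] = val;
  }
}

// Hypothetical explicit specialization: fill four floats with one 16-byte
// store instead of four scalar stores (buff must be 16B-aligned).
__device__ void arraySetFloat4(float* buff, float val) {
  *reinterpret_cast<float4*>(buff) = make_float4(val, val, val, val);
}
```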
return false;
}

// Shared memory is all aligned to 128 bits, local memory might not be
Just curious: do we already have 128-bit alignment for all smem, or is that in a different PR?
I thought we already aligned it; is that not the case?
Looks like TOT aligns to the data type size. I modified the static version in #1439.
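For reference, forcing 128-bit alignment on a static smem buffer is typically done with an alignment qualifier on the declaration. This is the generic CUDA idiom with a made-up kernel and buffer, not the actual change in #1439:

```cpp
__global__ void kernelSketch(float* out) {
  // A static smem array is by default aligned only to its element type
  // (4 bytes for float); __align__(16) guarantees 128-bit alignment so
  // vectorized (e.g. float4) access to the buffer is safe.
  __shared__ __align__(16) float smem_buffer[256];
  smem_buffer[threadIdx.x] = 0.f;
  __syncthreads();
  out[threadIdx.x] = smem_buffer[threadIdx.x];
}
```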
}

// Shared memory is all aligned to 128 bits, local memory might not be
if (this_tv->getMemoryType() == MemoryType::Local &&
We might want to move this to line 922. This seems to apply for outer sharing as well.
@shmsong can you double check I did this right?
Fixed instances where our vectorized support wasn't actually generating vectorized SASS. Tried to make the usage of Array more explicit in allocation, rather than dynamically casting to it.