Conversation

csarofeen (Owner)

Fixed instances where our vectorized support wasn't actually generating vectorized SASS. Tried to make the usage of Array in allocation more explicit, rather than dynamically casting to it.

@csarofeen requested review from naoyam and shmsong, February 10, 2022 19:36
@naoyam (Collaborator) left a comment:


LGTM. Added some questions and comments.

loadGeneric<scalar_t, vec_size>(to, from);
break;
case 8: {
uint2 const data = *reinterpret_cast<uint2 const*>(from);
Collaborator:

Just out of curiosity, is this copy necessary for using the inline assembly?

csarofeen (Owner, Author):

I think we might be able to use a const&, but the reinterpret_cast is necessary.

auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
if (uop->in()->isScalar()) {
if (out_tv->getMemoryType() == MemoryType::Local) {
// Vectorized intiialization
Collaborator:

typo: initialization

}

template <typename scalar_t, int vec_size>
__device__ void loadLocalToGlobal(scalar_t* to, scalar_t* from) {
Collaborator:

Do we also need to have variations for shared mem? Or does the compiler properly take care of those cases?

csarofeen (Owner, Author):

I think we'll probably want shared memory versions, maybe ldg.sts versions; it's open to expansion for sure. I don't have a case where smem doesn't do the right thing, but we're really not using that path much at the moment.

@shmsong (Feb 10, 2022):

I have seen auto-vectorization on smem accesses but haven't yet seen any de-vectorization, so I guess loadGeneric should be good for a while.

Also, a smem pointer would require a cvta instruction and an extra register for its output, which I thought might affect register allocation if we use asm to do it.

Comment on lines 1307 to 1314
if (alias_tv->getMemoryType() == MemoryType::Local &&
va.find(alias_tv) != va.end()) {
indent() << "auto& " << varName(tv) << " = " << varName(alias_tv)
<< ";\n";
} else {
indent() << buffer_dtype << "* " << varName(tv) << " = "
<< varName(alias_tv) << ";\n";
}
Collaborator:

Some comment describing why the local and vectorized case needs the special handling would be great.

I assume we could always generate auto& .... Is there any reason to prefer the pointer style of code?

csarofeen (Owner, Author):

I don't think there's any explicit benefit; it helps me look through the code quickly just because I know the difference, but agreed, we could just do auto&.

@shmsong left a comment:

Overall looks good to me. Some minor discussions.


// Used for vectorized allocations that are not in registers
template <typename scalar_t, int vec_size>
void arraySet(scalar_t* buff, scalar_t val) {
shmsong:

Would we need specializations for vectorized initialization? We could also rely on the compiler's auto-vectorization pass.

csarofeen (Owner, Author):

I didn't explicitly check, as it's unlikely to be a perf bottleneck for memory-bound ops. If it's necessary we could definitely do it.

shmsong:

Sure. I guess it would be necessary if we have bank conflicts in initialization or are limited by the instruction cache. I will add them if I see anything limited by that.

return false;
}

// Shared memory is all aligned to 128 bits, local memory might not be
shmsong:

Just curious, do we already have 128b alignment for all smem, or is that in a different PR?

csarofeen (Owner, Author):

I thought we already align it; is that not the case?

shmsong:

Looks like TOT aligns to data type size:

https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/codegen.cpp#L1293-L1308

I modified the static version in #1439.

}

// Shared memory is all aligned to 128 bits, local memory might not be
if (this_tv->getMemoryType() == MemoryType::Local &&
shmsong:

We might want to move this to line 922. This seems to apply to outer sharing as well.

csarofeen (Owner, Author):

@shmsong can you double check I did this right?

@csarofeen merged commit 44e8c15 into devel, Feb 11, 2022
@csarofeen deleted the vectorize_rework branch, May 7, 2022 23:52