forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Add utility for checking bank conflict of shared memory #2029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h> | ||
|
||
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h> | ||
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h> | ||
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h> | ||
#include <torch/csrc/jit/codegen/cuda/type.h> | ||
|
||
#include <unordered_set> | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace fuser { | ||
namespace cuda { | ||
|
||
namespace { | ||
|
||
bool isSmemTensorIndex(Val* value) { | ||
return value->isA<kir::TensorIndex>() && | ||
value->as<kir::TensorIndex>()->view()->getMemoryType() == | ||
MemoryType::Shared; | ||
} | ||
|
||
int64_t getVectorizeSize(kir::TensorIndex* ti) { | ||
for (auto id : ti->view()->domain()->domain()) { | ||
if (!isParallelTypeVectorize(id->getParallelType())) { | ||
continue; | ||
} | ||
|
||
ExpressionEvaluator expr_eval(id->fusion()); | ||
auto vector_size_optional = expr_eval.evaluate(id->extent()); | ||
|
||
TORCH_INTERNAL_ASSERT( | ||
vector_size_optional.has_value(), | ||
"Could not evaluate constant value bound to vectorized dim."); | ||
|
||
return vector_size_optional->as<int64_t>(); | ||
} | ||
return 1; | ||
} | ||
|
||
inline int64_t getPhaseSize(int64_t word_size_bytes) { | ||
if (word_size_bytes == 16) { | ||
return 8; | ||
} | ||
if (word_size_bytes == 8) { | ||
return 16; | ||
} | ||
return 32; | ||
} | ||
|
||
std::vector<int64_t> evaluateAddressesOnFirstPhase( | ||
kir::TensorIndex* ti, | ||
const std::vector<kir::ForLoop*>& for_loops) { | ||
std::vector<int64_t> addresses; | ||
const auto word_size_bytes = | ||
dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); | ||
int64_t phase_size = getPhaseSize(word_size_bytes); | ||
|
||
for (auto tidx : c10::irange(phase_size)) { | ||
int64_t index = 0; | ||
ExpressionEvaluator expr_eval(ti->fusion()); | ||
for (auto fl : for_loops) { | ||
if (fl->index()->isA<NamedScalar>() && | ||
fl->index()->as<NamedScalar>()->name() == "threadIdx.x") { | ||
expr_eval.bind(fl->index(), tidx); | ||
} else { | ||
expr_eval.bind(fl->index(), 0); | ||
} | ||
} | ||
for (auto ind : ti->indices()) { | ||
index += expr_eval.evaluate(ind)->as<int64_t>(); | ||
} | ||
addresses.emplace_back(index * word_size_bytes); | ||
} | ||
return addresses; | ||
} | ||
|
||
int getConflictWays(const std::vector<int64_t>& addresses) { | ||
std::unordered_set<int64_t> words_by_bank[32]; | ||
for (auto addr : addresses) { | ||
int64_t word = addr / 4; | ||
int64_t bank = word % 32; | ||
words_by_bank[bank].insert(word); | ||
} | ||
int conflict = 1; | ||
for (const auto& words : words_by_bank) { | ||
conflict = std::max<int>(conflict, words.size()); | ||
} | ||
return conflict; | ||
} | ||
|
||
} // namespace | ||
|
||
class BankConflictInfo : public kir::IrVisitor { | ||
public: | ||
static std::unordered_map<const Expr*, std::pair<int, int>> get( | ||
const std::vector<Expr*>& exprs) { | ||
return BankConflictInfo(exprs).bank_conflict_info_; | ||
} | ||
|
||
private: | ||
BankConflictInfo(const std::vector<Expr*>& exprs) { | ||
handle(exprs); | ||
} | ||
|
||
using kir::IrVisitor::handle; | ||
|
||
void handle(Expr* expr) final { | ||
if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) { | ||
kir::IrVisitor::handle(expr); | ||
return; | ||
} | ||
|
||
if (expr->isA<UnaryOp>()) { | ||
auto uop = expr->as<UnaryOp>(); | ||
if (uop->getUnaryOpType() != UnaryOpType::Set) { | ||
return; | ||
} | ||
std::pair<int, int> conflict_ways{0, 0}; | ||
if (isSmemTensorIndex(uop->in())) { | ||
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( | ||
uop->in()->as<kir::TensorIndex>(), for_loops_)); | ||
} | ||
if (isSmemTensorIndex(uop->out())) { | ||
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( | ||
uop->out()->as<kir::TensorIndex>(), for_loops_)); | ||
} | ||
if (conflict_ways.first > 1 || conflict_ways.second > 1) { | ||
bank_conflict_info_[expr] = conflict_ways; | ||
} | ||
} else if (expr->isA<LoadStoreOp>()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if |
||
auto ldst = expr->as<LoadStoreOp>(); | ||
std::pair<int, int> conflict_ways{0, 0}; | ||
if (isSmemTensorIndex(ldst->in())) { | ||
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( | ||
ldst->in()->as<kir::TensorIndex>(), for_loops_)); | ||
} | ||
if (isSmemTensorIndex(ldst->out())) { | ||
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( | ||
ldst->out()->as<kir::TensorIndex>(), for_loops_)); | ||
} | ||
if (conflict_ways.first > 1 || conflict_ways.second > 1) { | ||
bank_conflict_info_[expr] = conflict_ways; | ||
} | ||
} | ||
} | ||
|
||
std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_; | ||
}; | ||
|
||
std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo( | ||
kir::Kernel* kernel) { | ||
return BankConflictInfo::get(kernel->topLevelExprs()); | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace fuser | ||
} // namespace jit | ||
} // namespace torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h> | ||
#include <torch/csrc/jit/codegen/cuda/kernel.h> | ||
|
||
#include <unordered_map> | ||
#include <utility> | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace fuser { | ||
namespace cuda { | ||
|
||
// for more info on shared memory access see page 54-72 of: | ||
// https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf | ||
|
||
// Warning: The bank confliction checking utility here is not a replacement of | ||
// nsight compute. This utility currently has the following assumptions and | ||
// limitations: | ||
// | ||
// 1. This utility assumes that `blockDim.x` is large enough to hold one phase | ||
// 2. This utility assumes that the address only depends on loop variables | ||
// (there can not be a thing like `T0.stride[0]`, `blockDim.x`) | ||
// 3. This utility assumes that the data of the tensor is accessed by | ||
// `T0[index]`, where `index` is the one stored in the `TensorIndex` | ||
// object. | ||
// 4. This utility only checks the first iteration, and the start of all | ||
// loop variables are assumed to be `0` (if we have something like | ||
// `T1_s[tidx, 5]`, then different iterations should have different | ||
// results, which this utility will not be able to handle all of them now) | ||
// 5. This utility assumes that all tensors are independent, which means: | ||
// 5.1 All shared memory tensors are allocated starting from a multiple of | ||
// 4*32 bytes | ||
// 5.2 The only source of bank confliction is from within a tensor. | ||
// There is no bank conflict between different tensors. | ||
// | ||
// Also note that this utility will not provide accurate estimation if the above | ||
// assumptions are satisfied | ||
|
||
std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo( | ||
kir::Kernel* kernel); | ||
|
||
} // namespace cuda | ||
} // namespace fuser | ||
} // namespace jit | ||
} // namespace torch |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this assume
threadIdx.y/z
don't matter for bank conflicts? What about if we have, e.g.,blockDim.x == 1 && blockDim.y == 32
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I am currently making this assumption. But I can add a new overload that takes a launch parameter, which will lift this assumption.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with the current state as long as the assumptions are made clear. Not sure how important to make this more flexible at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am also OK with this assumption in this PR. I will write a few followup PRs to lift some of these assumptions.