Skip to content

Add utility for checking bank conflict of shared memory #2029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/lower_validation.cpp",
"torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
Expand Down
22 changes: 22 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <ATen/core/LegacyTypeDispatch.h>
Expand Down Expand Up @@ -251,6 +252,27 @@ void FusionExecutor::compileFusion(
kernel->print();
}

// When the BankConflictInfo debug dump option is enabled, run the static
// shared-memory bank conflict analysis on the lowered kernel and print a
// report to stdout before code generation.
if (isDebugDumpEnabled(DebugDumpOption::BankConflictInfo)) {
  const auto bank_conflict_info = getBankConflictInfo(kernel);
  if (bank_conflict_info.empty()) {
    std::cout << "===== No bank confliction =====" << std::endl;
  } else {
    std::cout << "======= Bank confliction =======" << std::endl;
    // Each entry maps an expression to its (input ways, output ways) pair.
    // Iterate by const reference to avoid copying every map entry.
    for (const auto& info : bank_conflict_info) {
      std::cout << "Expr: " << info.first->toString() << std::endl;
      const auto& conflict = info.second;
      // A way count of <= 1 means no conflict on that side, so it is omitted.
      if (conflict.first > 1) {
        std::cout << "input conflict: " << conflict.first << " way, ";
      }
      if (conflict.second > 1) {
        std::cout << "output conflict: " << conflict.second << " way";
      }
      std::cout << std::endl;
    }
    std::cout << "================================" << std::endl;
  }
}

kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
const auto structured_code = getStructuredCode(kernel_code_);

Expand Down
21 changes: 16 additions & 5 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) {
TORCH_CHECK(
value->definition() == nullptr,
"Tried to bind to a value that is computed in the fusion IR");
known_values_[value] = concrete_value;
if (value->isA<NamedScalar>()) {
known_named_scalars_[value->as<NamedScalar>()->name()] = concrete_value;
} else {
known_values_[value] = concrete_value;
}
}

c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) {
Expand Down Expand Up @@ -88,7 +92,7 @@ void ExpressionEvaluator::print() const {
c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) {
TORCH_INTERNAL_ASSERT(
value->isAnInt() || value->isADouble(),
"Expression Evaluation does not support values other than integers at this time.");
"Expression Evaluation does not support values other than integers/doubles at this time.");

if (value->getValType().value() == ValType::Scalar) {
if (value->isAnInt() && value->as<Int>()->value().has_value()) {
Expand All @@ -99,9 +103,16 @@ c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) {
}
}

const auto it = known_values_.find(value);
return it != known_values_.end() ? c10::optional<IntOrDouble>(it->second)
: c10::nullopt;
if (value->isA<NamedScalar>()) {
const auto it = known_named_scalars_.find(value->as<NamedScalar>()->name());
return it != known_named_scalars_.end()
? c10::optional<IntOrDouble>(it->second)
: c10::nullopt;
} else {
const auto it = known_values_.find(value);
return it != known_values_.end() ? c10::optional<IntOrDouble>(it->second)
: c10::nullopt;
}
}

void ExpressionEvaluator::handle(UnaryOp* uop) {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {

void handle(UnaryOp*) final;
void handle(BinaryOp*) final;
// TODO: handle swizzle

private:
std::unordered_map<const Val*, IntOrDouble> known_values_;
std::unordered_map<std::string, IntOrDouble> known_named_scalars_;
Fusion* fusion_ = nullptr;
FusionPrecomputedValues* evaluator_precomputed_values_ = nullptr;
};
Expand Down
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -339,6 +340,20 @@ void Fusion::printKernel(DataType index_type) {
std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel());
}

std::unordered_map<std::string, std::pair<int, int>> Fusion::bankConflictInfo(
    DataType index_type) {
  // Lower this fusion to a kernel and run the bank conflict analysis on it.
  GpuLower lower(this, index_type);
  auto kernel = lower.kernel();
  auto info = getBankConflictInfo(kernel);
  // The Expr* keys are owned by the lowered kernel, which goes out of scope
  // when this function returns, so convert the keys to strings here.
  std::unordered_map<std::string, std::pair<int, int>> result;
  result.reserve(info.size());
  // Iterate by const reference: each entry is a pair<const Expr*,
  // pair<int, int>> and would otherwise be copied on every iteration.
  for (const auto& entry : info) {
    result[entry.first->toString()] = entry.second;
  }
  return result;
}

void Fusion::printMath(bool from_outputs_only) {
FUSER_PERF_SCOPE("Fusion::printMath");

Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! Lower the fusion and print a kernel
void printKernel(DataType index_type = DataType::Int);

//! Lower the fusion and evaluate bank conflict info
std::unordered_map<std::string, std::pair<int, int>> bankConflictInfo(
DataType index_type = DataType::Int);

//! Return a list of topologically sorted expressions. This only includes
//! exprs required to genereate registered outputs.
std::vector<Expr*> exprs();
Expand Down
159 changes: 159 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>

#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Returns true when `value` is a kir::TensorIndex whose underlying tensor
// view is allocated in shared memory.
bool isSmemTensorIndex(Val* value) {
  if (!value->isA<kir::TensorIndex>()) {
    return false;
  }
  const auto ti = value->as<kir::TensorIndex>();
  return ti->view()->getMemoryType() == MemoryType::Shared;
}

// Returns the constant extent of the first Vectorize-parallelized axis in the
// domain of the tensor accessed through `ti`, or 1 when no axis is
// vectorized. Asserts if the vectorized extent cannot be evaluated to a
// constant.
int64_t getVectorizeSize(kir::TensorIndex* ti) {
  for (auto id : ti->view()->domain()->domain()) {
    if (isParallelTypeVectorize(id->getParallelType())) {
      ExpressionEvaluator expr_eval(id->fusion());
      const auto vector_size = expr_eval.evaluate(id->extent());
      TORCH_INTERNAL_ASSERT(
          vector_size.has_value(),
          "Could not evaluate constant value bound to vectorized dim.");
      return vector_size->as<int64_t>();
    }
  }
  // No vectorized axis: each access moves a single element.
  return 1;
}

// Returns how many threads access shared memory in one phase for the given
// per-thread word size: wider words split a warp's access into more phases,
// so fewer threads participate per phase.
inline int64_t getPhaseSize(int64_t word_size_bytes) {
  switch (word_size_bytes) {
    case 16:
      return 8;
    case 8:
      return 16;
    default:
      return 32;
  }
}

// Simulates the byte addresses accessed in shared memory by the first phase
// of a warp-wide access to `ti`: one address is produced per threadIdx.x
// value in [0, phase_size). Every other loop index (serial loops, and any
// other parallel index) is evaluated at 0, so only the first iteration is
// modeled. NOTE(review): this assumes threadIdx.y/z do not affect the
// address (i.e. blockDim.x alone covers a full phase) — see the assumptions
// documented in lower_bank_conflict.h.
std::vector<int64_t> evaluateAddressesOnFirstPhase(
    kir::TensorIndex* ti,
    const std::vector<kir::ForLoop*>& for_loops) {
  // Effective access width per thread: element size times vectorize factor.
  const auto word_size_bytes =
      dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
  const int64_t phase_size = getPhaseSize(word_size_bytes);

  std::vector<int64_t> addresses;
  // The number of addresses is known up front; reserve to avoid reallocation.
  addresses.reserve(phase_size);

  for (auto tidx : c10::irange(phase_size)) {
    int64_t index = 0;
    ExpressionEvaluator expr_eval(ti->fusion());
    for (auto fl : for_loops) {
      // Only threadIdx.x varies across the phase; all other loop indices are
      // pinned to their first iteration (0).
      if (fl->index()->isA<NamedScalar>() &&
          fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
        expr_eval.bind(fl->index(), tidx);
      } else {
        expr_eval.bind(fl->index(), 0);
      }
    }
    // The flattened element index is the sum of the per-dimension indices
    // stored in the TensorIndex.
    for (auto ind : ti->indices()) {
      index += expr_eval.evaluate(ind)->as<int64_t>();
    }
    addresses.emplace_back(index * word_size_bytes);
  }
  return addresses;
}

// Computes the n-way bank conflict factor of a set of simultaneously accessed
// byte addresses: shared memory is organized as 32 banks of 4-byte words, and
// the conflict factor is the largest number of *distinct* words that fall in
// the same bank (repeated accesses to one word broadcast and do not conflict).
int getConflictWays(const std::vector<int64_t>& addresses) {
  std::unordered_set<int64_t> distinct_words_per_bank[32];
  for (auto address : addresses) {
    const int64_t word = address / 4;
    distinct_words_per_bank[word % 32].insert(word);
  }
  int ways = 1;
  for (const auto& bank_words : distinct_words_per_bank) {
    ways = std::max<int>(ways, bank_words.size());
  }
  return ways;
}

} // namespace

class BankConflictInfo : public kir::IrVisitor {
public:
static std::unordered_map<const Expr*, std::pair<int, int>> get(
const std::vector<Expr*>& exprs) {
return BankConflictInfo(exprs).bank_conflict_info_;
}

private:
BankConflictInfo(const std::vector<Expr*>& exprs) {
handle(exprs);
}

using kir::IrVisitor::handle;

void handle(Expr* expr) final {
if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
kir::IrVisitor::handle(expr);
return;
}

if (expr->isA<UnaryOp>()) {
auto uop = expr->as<UnaryOp>();
if (uop->getUnaryOpType() != UnaryOpType::Set) {
return;
}
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(uop->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
uop->in()->as<kir::TensorIndex>(), for_loops_));
}
if (isSmemTensorIndex(uop->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
uop->out()->as<kir::TensorIndex>(), for_loops_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
}
} else if (expr->isA<LoadStoreOp>()) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if ld.matrix is supported correctly here. Will need to dig deeper to understand how it works. I will update this in later PR.

auto ldst = expr->as<LoadStoreOp>();
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(ldst->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
ldst->in()->as<kir::TensorIndex>(), for_loops_));
}
if (isSmemTensorIndex(ldst->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
ldst->out()->as<kir::TensorIndex>(), for_loops_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
}
}
}

std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
};

// Entry point: analyze an entire lowered kernel for shared-memory bank
// conflicts. See the header for the assumptions this analysis makes.
std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
    kir::Kernel* kernel) {
  const auto& top_level_exprs = kernel->topLevelExprs();
  return BankConflictInfo::get(top_level_exprs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
46 changes: 46 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>

#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// for more info on shared memory access see page 54-72 of:
// https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf

// Warning: The bank conflict checking utility here is not a replacement for
// Nsight Compute. This utility currently has the following assumptions and
// limitations:
//
// 1. This utility assumes that `blockDim.x` is large enough to hold one phase
// 2. This utility assumes that the address only depends on loop variables
//    (the index expression must not reference values such as `T0.stride[0]`
//    or `blockDim.x`)
// 3. This utility assumes that the data of the tensor is accessed by
// `T0[index]`, where `index` is the one stored in the `TensorIndex`
// object.
// 4. This utility only checks the first iteration, and the start of all
// loop variables are assumed to be `0` (if we have something like
// `T1_s[tidx, 5]`, then different iterations should have different
// results, which this utility will not be able to handle all of them now)
// 5. This utility assumes that all tensors are independent, which means:
// 5.1 All shared memory tensors are allocated starting from a multiple of
// 4*32 bytes
// 5.2 The only source of bank confliction is from within a tensor.
// There is no bank conflict between different tensors.
//
// Also note that this utility will not provide an accurate estimation if the
// above assumptions are not satisfied

std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
kir::Kernel* kernel);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
Loading