Commit f5bca33

Bank conflict checker improvements (#2032)
1 parent d2ca7e3 commit f5bca33

5 files changed: +327 −28 lines

torch/csrc/jit/codegen/cuda/expr_evaluator.cpp

Lines changed: 6 additions & 0 deletions

@@ -61,6 +61,12 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) {
   }
 }
 
+void ExpressionEvaluator::bind(
+    const std::string& name,
+    const IntOrDouble& concrete_value) {
+  known_named_scalars_[name] = concrete_value;
+}
+
 c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) {
   if (evaluator_precomputed_values_ != nullptr) {
     return toOptionalIntOrDouble(

torch/csrc/jit/codegen/cuda/expr_evaluator.h

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,7 @@
 
 #include <c10/util/Optional.h>
 
+#include <string>
 #include <unordered_map>
 
 namespace torch {
@@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
   //! Bind a concrete value to an IR variable
   void bind(Val* value, const IntOrDouble& concrete_value);
 
+  //! Bind a concrete value to a named scalar
+  void bind(const std::string& name, const IntOrDouble& concrete_value);
+
   //! Try to evaluate a Fusion IR value
   c10::optional<IntOrDouble> evaluate(Val* value);
 
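For orientation, here is a minimal sketch (not part of this diff) of how the new named-scalar overload works alongside the existing Val* overload; `fusion`, `loop_index`, and `index_expr` are hypothetical stand-ins for a Fusion and two index Vals taken from a kernel:

ExpressionEvaluator expr_eval(fusion);
// New overload: bind built-in CUDA variables by name.
expr_eval.bind("threadIdx.x", 3);
expr_eval.bind("blockDim.x", 128);
// Existing overload: bind an IR value by its Val* handle.
expr_eval.bind(loop_index, 0);
// evaluate() returns c10::optional<IntOrDouble>; it is empty when some
// input of the expression is still unbound.
auto result = expr_eval.evaluate(index_expr);
if (result.has_value()) {
  int64_t concrete_index = result->as<int64_t>();
}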

torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp

Lines changed: 190 additions & 17 deletions

@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
 
+#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
@@ -48,23 +49,78 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) {
   return 32;
 }
 
+bool isThreadIdx(const std::string& name) {
+  return name == "threadIdx.x" || name == "threadIdx.y" ||
+      name == "threadIdx.z";
+}
+
+bool isBlockIdx(const std::string& name) {
+  return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z";
+}
+
+bool isBlockDim(const std::string& name) {
+  return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z";
+}
+
+bool isGridDim(const std::string& name) {
+  return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z";
+}
+
+ParallelType getParallelType(const std::string& name) {
+  if (name == "threadIdx.x") {
+    return ParallelType::TIDx;
+  } else if (name == "threadIdx.y") {
+    return ParallelType::TIDy;
+  } else if (name == "threadIdx.z") {
+    return ParallelType::TIDz;
+  } else if (name == "blockIdx.x") {
+    return ParallelType::BIDx;
+  } else if (name == "blockIdx.y") {
+    return ParallelType::BIDy;
+  } else if (name == "blockIdx.z") {
+    return ParallelType::BIDz;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
+}
+
 std::vector<int64_t> evaluateAddressesOnFirstPhase(
     kir::TensorIndex* ti,
-    const std::vector<kir::ForLoop*>& for_loops) {
+    const std::vector<kir::ForLoop*>& for_loops,
+    c10::optional<LaunchParams> launch_params,
+    const ExpressionEvaluator& expr_eval_common) {
   std::vector<int64_t> addresses;
   const auto word_size_bytes =
       dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
   int64_t phase_size = getPhaseSize(word_size_bytes);
 
-  for (auto tidx : c10::irange(phase_size)) {
+  if (launch_params.has_value()) {
+    phase_size = std::min<int64_t>(phase_size, launch_params->nThreads());
+  }
+
+  for (int64_t linear_tidx : c10::irange(phase_size)) {
+    int64_t tidx = linear_tidx;
+    int64_t tidy = 0;
+    int64_t tidz = 0;
+    if (launch_params.has_value()) {
+      tidy = tidx / launch_params->bdimx();
+      tidx = tidx % launch_params->bdimx();
+      tidz = tidy / launch_params->bdimy();
+      tidy = tidy % launch_params->bdimy();
+    }
     int64_t index = 0;
-    ExpressionEvaluator expr_eval(ti->fusion());
+    // make a copy of the expression evaluator
+    ExpressionEvaluator expr_eval = expr_eval_common;
+    expr_eval.bind("threadIdx.x", tidx);
+    expr_eval.bind("threadIdx.y", tidy);
+    expr_eval.bind("threadIdx.z", tidz);
     for (auto fl : for_loops) {
-      if (fl->index()->isA<NamedScalar>() &&
-          fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
-        expr_eval.bind(fl->index(), tidx);
+      if (fl->index()->isA<NamedScalar>()) {
+        auto name = fl->index()->as<NamedScalar>()->name();
+        TORCH_INTERNAL_ASSERT(
+            isThreadIdx(name) || isBlockIdx(name), "unknown loop index");
       } else {
-        expr_eval.bind(fl->index(), 0);
+        auto start = expr_eval.evaluate(fl->start())->as<int64_t>();
+        expr_eval.bind(fl->index(), start);
       }
     }
     for (auto ind : ti->indices()) {
@@ -89,17 +145,97 @@ int getConflictWays(const std::vector<int64_t>& addresses) {
   return conflict;
 }
 
-} // namespace
+class InferLaunchParams : public kir::IrVisitor {
+ public:
+  static c10::optional<LaunchParams> get(
+      const std::vector<Expr*>& exprs,
+      const std::unordered_map<std::string, IntOrDouble>& known_values) {
+    if (exprs.empty()) {
+      return c10::nullopt;
+    }
+    return InferLaunchParams(exprs, known_values).launch_params_;
+  }
+
+ private:
+  InferLaunchParams(
+      const std::vector<Expr*>& exprs,
+      const std::unordered_map<std::string, IntOrDouble>& known_values)
+      : expr_eval_(exprs[0]->fusion()) {
+    for (auto pair : known_values) {
+      expr_eval_.bind(pair.first, pair.second);
+    }
+    handle(exprs);
+  }
+
+  using kir::IrVisitor::handle;
+
+  void handle(Expr* expr) final {
+    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
+      kir::IrVisitor::handle(expr);
+      return;
+    }
+
+    for (auto fl : for_loops_) {
+      if (fl->index()->isA<NamedScalar>()) {
+        auto name = fl->index()->as<NamedScalar>()->name();
+        if (isThreadIdx(name) || isBlockIdx(name)) {
+          auto ptype = getParallelType(name);
+          auto stop = expr_eval_.evaluate(fl->stop());
+          if (stop.has_value()) {
+            if (!launch_params_.has_value()) {
+              launch_params_ = LaunchParams();
+            }
+            if (launch_params_->getRawVal(ptype) ==
+                LaunchParams::UNINITIALIZED_VAL) {
+              launch_params_->bind(stop->as<int64_t>(), ptype);
+            } else {
+              TORCH_INTERNAL_ASSERT(
+                  launch_params_->getDim(ptype) == stop,
+                  "Unable to infer launch parameters");
+            }
+          }
+        }
+      }
+    }
+  }
+
+  ExpressionEvaluator expr_eval_;
+  c10::optional<LaunchParams> launch_params_;
+};
 
 class BankConflictInfo : public kir::IrVisitor {
  public:
   static std::unordered_map<const Expr*, std::pair<int, int>> get(
-      const std::vector<Expr*>& exprs) {
-    return BankConflictInfo(exprs).bank_conflict_info_;
+      const std::vector<Expr*>& exprs,
+      c10::optional<LaunchParams> launch_params,
+      const std::unordered_map<std::string, IntOrDouble>& known_values) {
+    if (exprs.empty()) {
+      return {};
+    }
+    return BankConflictInfo(exprs, launch_params, known_values)
+        .bank_conflict_info_;
   }
 
  private:
-  BankConflictInfo(const std::vector<Expr*>& exprs) {
+  BankConflictInfo(
+      const std::vector<Expr*>& exprs,
+      c10::optional<LaunchParams> launch_params,
+      const std::unordered_map<std::string, IntOrDouble>& known_values)
+      : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) {
+    expr_eval_common_.bind("blockIdx.x", 0);
+    expr_eval_common_.bind("blockIdx.y", 0);
+    expr_eval_common_.bind("blockIdx.z", 0);
+    if (launch_params.has_value()) {
+      expr_eval_common_.bind("blockDim.x", launch_params->bdimx());
+      expr_eval_common_.bind("blockDim.y", launch_params->bdimy());
+      expr_eval_common_.bind("blockDim.z", launch_params->bdimz());
+      expr_eval_common_.bind("gridDim.x", launch_params->gdimx());
+      expr_eval_common_.bind("gridDim.y", launch_params->gdimy());
+      expr_eval_common_.bind("gridDim.z", launch_params->gdimz());
+    }
+    for (auto pair : known_values) {
+      expr_eval_common_.bind(pair.first, pair.second);
+    }
     handle(exprs);
   }
 
@@ -119,11 +255,17 @@ class BankConflictInfo : public kir::IrVisitor {
     std::pair<int, int> conflict_ways{0, 0};
     if (isSmemTensorIndex(uop->in())) {
       conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
-          uop->in()->as<kir::TensorIndex>(), for_loops_));
+          uop->in()->as<kir::TensorIndex>(),
+          for_loops_,
+          launch_params_,
+          expr_eval_common_));
     }
     if (isSmemTensorIndex(uop->out())) {
       conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
-          uop->out()->as<kir::TensorIndex>(), for_loops_));
+          uop->out()->as<kir::TensorIndex>(),
+          for_loops_,
+          launch_params_,
+          expr_eval_common_));
     }
     if (conflict_ways.first > 1 || conflict_ways.second > 1) {
       bank_conflict_info_[expr] = conflict_ways;
@@ -133,11 +275,17 @@ class BankConflictInfo : public kir::IrVisitor {
     std::pair<int, int> conflict_ways{0, 0};
     if (isSmemTensorIndex(ldst->in())) {
       conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
-          ldst->in()->as<kir::TensorIndex>(), for_loops_));
+          ldst->in()->as<kir::TensorIndex>(),
+          for_loops_,
+          launch_params_,
+          expr_eval_common_));
     }
     if (isSmemTensorIndex(ldst->out())) {
       conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
-          ldst->out()->as<kir::TensorIndex>(), for_loops_));
+          ldst->out()->as<kir::TensorIndex>(),
+          for_loops_,
+          launch_params_,
+          expr_eval_common_));
     }
     if (conflict_ways.first > 1 || conflict_ways.second > 1) {
       bank_conflict_info_[expr] = conflict_ways;
@@ -146,11 +294,36 @@ class BankConflictInfo : public kir::IrVisitor {
   }
 
   std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
+  c10::optional<LaunchParams> launch_params_;
+  ExpressionEvaluator expr_eval_common_;
 };
 
+} // namespace
+
 std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
-    kir::Kernel* kernel) {
-  return BankConflictInfo::get(kernel->topLevelExprs());
+    kir::Kernel* kernel,
+    c10::optional<LaunchParams> launch_params,
+    const std::unordered_map<std::string, IntOrDouble>& known_values) {
+  for (auto pair : known_values) {
+    TORCH_CHECK(
+        !isThreadIdx(pair.first),
+        "threadIdx.{x,y,z} should be computed instead of provided");
+    TORCH_CHECK(
+        !isBlockIdx(pair.first),
+        "blockIdx.{x,y,z} should not be provided (they are always zero)");
+    TORCH_CHECK(
+        !isBlockDim(pair.first),
+        "blockDim.{x,y,z} should be provided by launch_params");
+    TORCH_CHECK(
+        !isGridDim(pair.first),
+        "gridDim.{x,y,z} should be provided by launch_params");
+  }
+  if (!launch_params.has_value()) {
+    launch_params =
+        InferLaunchParams::get(kernel->topLevelExprs(), known_values);
+  }
+  return BankConflictInfo::get(
+      kernel->topLevelExprs(), launch_params, known_values);
 }
 
 } // namespace cuda
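As background for evaluateAddressesOnFirstPhase and getConflictWays above: shared memory on current NVIDIA GPUs is divided into 32 banks of 4 bytes, so byte address A maps to bank (A / 4) % 32, and the conflict "ways" of an access is the largest number of distinct 4-byte words within one phase that fall into the same bank. The following is a self-contained sketch of that counting idea, not the commit's exact implementation:

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Byte address -> bank index, assuming 32 banks of 4 bytes each.
int64_t bankOf(int64_t address) {
  return (address / 4) % 32;
}

// Count the largest number of distinct words that land in the same bank.
// Duplicate words in a bank are a broadcast, not a conflict, hence the set
// of words rather than a raw count.
int conflictWaysSketch(const std::vector<int64_t>& addresses) {
  std::unordered_map<int64_t, std::unordered_set<int64_t>> words_per_bank;
  for (int64_t addr : addresses) {
    words_per_bank[bankOf(addr)].insert(addr / 4);
  }
  int ways = 1; // 1 means conflict-free
  for (const auto& pair : words_per_bank) {
    ways = std::max<int>(ways, pair.second.size());
  }
  return ways;
}

For example, addresses 0, 128, 256, ... all map to bank 0, so a warp whose threads stride by 128 bytes sees a 32-way conflict, while a warp in which every thread reads the same word is a broadcast and remains conflict-free.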

torch/csrc/jit/codegen/cuda/lower_bank_conflict.h

Lines changed: 11 additions & 11 deletions

@@ -1,5 +1,7 @@
 #pragma once
 
+#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
+#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
 
@@ -18,27 +20,25 @@ namespace cuda {
 // nsight compute. This utility currently has the following assumptions and
 // limitations:
 //
-// 1. This utility assumes that `blockDim.x` is large enough to hold one phase
-// 2. This utility assumes that the address only depends on loop variables
-//    (there can not be a thing like `T0.stride[0]`, `blockDim.x`)
-// 3. This utility assumes that the data of the tensor is accessed by
+// 1. This utility assumes that the data of the tensor is accessed by
 //    `T0[index]`, where `index` is the one stored in the `TensorIndex`
 //    object.
-// 4. This utility only checks the first iteration, and the start of all
-//    loop variables are assumed to be `0` (if we have something like
+// 2. This utility only checks the first iteration. If we have something like
 //    `T1_s[tidx, 5]`, then different iterations should have different
-//    results, which this utility will not be able to handle all of them now)
-// 5. This utility assumes that all tensors are independent, which means:
-//    5.1 All shared memory tensors are allocated starting from a multiple of
+//    conflicts, not all of which will be evaluated
+// 3. This utility assumes that all tensors are independent, which means:
+//    3.1 All shared memory tensors are allocated starting from a multiple of
 //        4*32 bytes
-//    5.2 The only source of bank confliction is from within a tensor.
+//    3.2 The only source of bank conflict is from within a tensor.
 //        There is no bank conflict between different tensors.
 //
 // Also note that this utility will not provide accurate estimation if the
 // above assumptions are not satisfied.
 
 std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
-    kir::Kernel* kernel);
+    kir::Kernel* kernel,
+    c10::optional<LaunchParams> launch_params = c10::nullopt,
+    const std::unordered_map<std::string, IntOrDouble>& known_values = {});
 
 } // namespace cuda
 } // namespace fuser
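To close, a minimal usage sketch of the extended entry point, assuming a lowered kir::Kernel* named `kernel`; the scalar name "T0.stride[0]" is illustrative, and LaunchParams::bind and ParallelType are used as shown in the diff above:

// Option 1: omit launch_params and let InferLaunchParams derive block
// dimensions from the parallelized loop extents, where possible.
auto info_inferred = getBankConflictInfo(kernel);

// Option 2: supply launch parameters explicitly, plus any named scalars
// the indexing math depends on ("T0.stride[0]" is a hypothetical example).
LaunchParams lp;
lp.bind(128, ParallelType::TIDx); // blockDim.x = 128
lp.bind(2, ParallelType::TIDy);   // blockDim.y = 2
std::unordered_map<std::string, IntOrDouble> known_values;
known_values["T0.stride[0]"] = IntOrDouble(256);
auto info = getBankConflictInfo(kernel, lp, known_values);

// Each entry maps a shared-memory access Expr* to {input ways, output ways};
// only accesses where either value exceeds 1 are reported.
for (const auto& kv : info) {
  const Expr* access = kv.first;
  std::pair<int, int> ways = kv.second; // e.g. {32, 0} for a 32-way read
  // report or assert on `access` and `ways` here
}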
