 #include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
 
+#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
@@ -48,23 +49,78 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) {
   return 32;
 }
 
+bool isThreadIdx(const std::string& name) {
+  return name == "threadIdx.x" || name == "threadIdx.y" ||
+      name == "threadIdx.z";
+}
+
+bool isBlockIdx(const std::string& name) {
+  return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z";
+}
+
+bool isBlockDim(const std::string& name) {
+  return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z";
+}
+
+bool isGridDim(const std::string& name) {
+  return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z";
+}
+
+ParallelType getParallelType(const std::string& name) {
+  if (name == "threadIdx.x") {
+    return ParallelType::TIDx;
+  } else if (name == "threadIdx.y") {
+    return ParallelType::TIDy;
+  } else if (name == "threadIdx.z") {
+    return ParallelType::TIDz;
+  } else if (name == "blockIdx.x") {
+    return ParallelType::BIDx;
+  } else if (name == "blockIdx.y") {
+    return ParallelType::BIDy;
+  } else if (name == "blockIdx.z") {
+    return ParallelType::BIDz;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
+}
+
 std::vector<int64_t> evaluateAddressesOnFirstPhase(
     kir::TensorIndex* ti,
-    const std::vector<kir::ForLoop*>& for_loops) {
+    const std::vector<kir::ForLoop*>& for_loops,
+    c10::optional<LaunchParams> launch_params,
+    const ExpressionEvaluator& expr_eval_common) {
   std::vector<int64_t> addresses;
   const auto word_size_bytes =
       dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
   int64_t phase_size = getPhaseSize(word_size_bytes);
 
-  for (auto tidx : c10::irange(phase_size)) {
+  if (launch_params.has_value()) {
+    phase_size = std::min<int64_t>(phase_size, launch_params->nThreads());
+  }
+
+  for (int64_t linear_tidx : c10::irange(phase_size)) {
+    int64_t tidx = linear_tidx;
+    int64_t tidy = 0;
+    int64_t tidz = 0;
+    if (launch_params.has_value()) {
+      tidy = tidx / launch_params->bdimx();
+      tidx = tidx % launch_params->bdimx();
+      tidz = tidy / launch_params->bdimy();
+      tidy = tidy % launch_params->bdimy();
+    }
     int64_t index = 0;
-    ExpressionEvaluator expr_eval(ti->fusion());
+    // make a copy of the expression evaluator
+    ExpressionEvaluator expr_eval = expr_eval_common;
+    expr_eval.bind("threadIdx.x", tidx);
+    expr_eval.bind("threadIdx.y", tidy);
+    expr_eval.bind("threadIdx.z", tidz);
     for (auto fl : for_loops) {
-      if (fl->index()->isA<NamedScalar>() &&
-          fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
-        expr_eval.bind(fl->index(), tidx);
+      if (fl->index()->isA<NamedScalar>()) {
+        auto name = fl->index()->as<NamedScalar>()->name();
+        TORCH_INTERNAL_ASSERT(
+            isThreadIdx(name) || isBlockIdx(name), "unknown loop index");
       } else {
-        expr_eval.bind(fl->index(), 0);
+        auto start = expr_eval.evaluate(fl->start())->as<int64_t>();
+        expr_eval.bind(fl->index(), start);
       }
     }
     for (auto ind : ti->indices()) {
@@ -89,17 +145,97 @@ int getConflictWays(const std::vector<int64_t>& addresses) {
   return conflict;
 }
 
-} // namespace
+class InferLaunchParams : public kir::IrVisitor {
+ public:
+  static c10::optional<LaunchParams> get(
+      const std::vector<Expr*>& exprs,
+      const std::unordered_map<std::string, IntOrDouble>& known_values) {
+    if (exprs.empty()) {
+      return c10::nullopt;
+    }
+    return InferLaunchParams(exprs, known_values).launch_params_;
+  }
+
+ private:
+  InferLaunchParams(
+      const std::vector<Expr*>& exprs,
+      const std::unordered_map<std::string, IntOrDouble>& known_values)
+      : expr_eval_(exprs[0]->fusion()) {
+    for (auto pair : known_values) {
+      expr_eval_.bind(pair.first, pair.second);
+    }
+    handle(exprs);
+  }
+
+  using kir::IrVisitor::handle;
+
+  void handle(Expr* expr) final {
+    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
+      kir::IrVisitor::handle(expr);
+      return;
+    }
+
+    for (auto fl : for_loops_) {
+      if (fl->index()->isA<NamedScalar>()) {
+        auto name = fl->index()->as<NamedScalar>()->name();
+        if (isThreadIdx(name) || isBlockIdx(name)) {
+          auto ptype = getParallelType(name);
+          auto stop = expr_eval_.evaluate(fl->stop());
+          if (stop.has_value()) {
+            if (!launch_params_.has_value()) {
+              launch_params_ = LaunchParams();
+            }
+            if (launch_params_->getRawVal(ptype) ==
+                LaunchParams::UNINITIALIZED_VAL) {
+              launch_params_->bind(stop->as<int64_t>(), ptype);
+            } else {
+              TORCH_INTERNAL_ASSERT(
+                  launch_params_->getDim(ptype) == stop,
+                  "Unable to infer launch parameters");
+            }
+          }
+        }
+      }
+    }
+  }
+
+  ExpressionEvaluator expr_eval_;
+  c10::optional<LaunchParams> launch_params_;
+};
 
 class BankConflictInfo : public kir::IrVisitor {
  public:
   static std::unordered_map<const Expr*, std::pair<int, int>> get(
-      const std::vector<Expr*>& exprs) {
-    return BankConflictInfo(exprs).bank_conflict_info_;
+      const std::vector<Expr*>& exprs,
+      c10::optional<LaunchParams> launch_params,
+      const std::unordered_map<std::string, IntOrDouble>& known_values) {
+    if (exprs.empty()) {
+      return {};
+    }
+    return BankConflictInfo(exprs, launch_params, known_values)
+        .bank_conflict_info_;
   }
 
  private:
-  BankConflictInfo(const std::vector<Expr*>& exprs) {
+  BankConflictInfo(
+      const std::vector<Expr*>& exprs,
+      c10::optional<LaunchParams> launch_params,
+      const std::unordered_map<std::string, IntOrDouble>& known_values)
+      : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) {
+    expr_eval_common_.bind("blockIdx.x", 0);
+    expr_eval_common_.bind("blockIdx.y", 0);
+    expr_eval_common_.bind("blockIdx.z", 0);
+    if (launch_params.has_value()) {
+      expr_eval_common_.bind("blockDim.x", launch_params->bdimx());
+      expr_eval_common_.bind("blockDim.y", launch_params->bdimy());
+      expr_eval_common_.bind("blockDim.z", launch_params->bdimz());
+      expr_eval_common_.bind("gridDim.x", launch_params->gdimx());
+      expr_eval_common_.bind("gridDim.y", launch_params->gdimy());
+      expr_eval_common_.bind("gridDim.z", launch_params->gdimz());
+    }
+    for (auto pair : known_values) {
+      expr_eval_common_.bind(pair.first, pair.second);
+    }
     handle(exprs);
   }
@@ -119,11 +255,17 @@ class BankConflictInfo : public kir::IrVisitor {
       std::pair<int, int> conflict_ways{0, 0};
       if (isSmemTensorIndex(uop->in())) {
         conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
-            uop->in()->as<kir::TensorIndex>(), for_loops_));
+            uop->in()->as<kir::TensorIndex>(),
+            for_loops_,
+            launch_params_,
+            expr_eval_common_));
       }
       if (isSmemTensorIndex(uop->out())) {
         conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
-            uop->out()->as<kir::TensorIndex>(), for_loops_));
+            uop->out()->as<kir::TensorIndex>(),
+            for_loops_,
+            launch_params_,
+            expr_eval_common_));
       }
       if (conflict_ways.first > 1 || conflict_ways.second > 1) {
         bank_conflict_info_[expr] = conflict_ways;
@@ -133,11 +275,17 @@ class BankConflictInfo : public kir::IrVisitor {
       std::pair<int, int> conflict_ways{0, 0};
       if (isSmemTensorIndex(ldst->in())) {
         conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
-            ldst->in()->as<kir::TensorIndex>(), for_loops_));
+            ldst->in()->as<kir::TensorIndex>(),
+            for_loops_,
+            launch_params_,
+            expr_eval_common_));
       }
       if (isSmemTensorIndex(ldst->out())) {
         conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
-            ldst->out()->as<kir::TensorIndex>(), for_loops_));
+            ldst->out()->as<kir::TensorIndex>(),
+            for_loops_,
+            launch_params_,
+            expr_eval_common_));
       }
       if (conflict_ways.first > 1 || conflict_ways.second > 1) {
         bank_conflict_info_[expr] = conflict_ways;
@@ -146,11 +294,36 @@ class BankConflictInfo : public kir::IrVisitor {
   }
 
   std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
+  c10::optional<LaunchParams> launch_params_;
+  ExpressionEvaluator expr_eval_common_;
 };
 
+} // namespace
+
 std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
-    kir::Kernel* kernel) {
-  return BankConflictInfo::get(kernel->topLevelExprs());
+    kir::Kernel* kernel,
+    c10::optional<LaunchParams> launch_params,
+    const std::unordered_map<std::string, IntOrDouble>& known_values) {
+  for (auto pair : known_values) {
+    TORCH_CHECK(
+        !isThreadIdx(pair.first),
+        "threadIdx.{x,y,z} should be computed instead of provided");
+    TORCH_CHECK(
+        !isBlockIdx(pair.first),
+        "blockIdx.{x,y,z} should not be provided (they are always zero)");
+    TORCH_CHECK(
+        !isBlockDim(pair.first),
+        "blockDim.{x,y,z} should be provided by launch_params");
+    TORCH_CHECK(
+        !isGridDim(pair.first),
+        "gridDim.{x,y,z} should be provided by launch_params");
+  }
+  if (!launch_params.has_value()) {
+    launch_params =
+        InferLaunchParams::get(kernel->topLevelExprs(), known_values);
+  }
+  return BankConflictInfo::get(
+      kernel->topLevelExprs(), launch_params, known_values);
 }
 
} // namespace cuda
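
For reference, a minimal usage sketch of the extended getBankConflictInfo entry point introduced by this diff, supplying explicit launch dimensions and extra scalar bindings. The GpuLower setup, the header paths, the scalar name "i0", and the concrete sizes are illustrative assumptions, not part of this change; LaunchParams::bind and getBankConflictInfo are used as defined above.

// Hypothetical caller (not in this diff): query shared-memory bank conflicts
// for a lowered kernel with a known CTA shape and a bound dynamic extent.
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>

#include <iostream>
#include <unordered_map>

using namespace torch::jit::fuser::cuda;

void reportBankConflicts(Fusion* fusion) {
  GpuLower gpulw(fusion);

  LaunchParams lp;
  lp.bind(128, ParallelType::TIDx); // assumed CTA shape: 128 x 1 x 1

  // Bind named scalars the evaluator cannot infer; "i0" is a made-up extent.
  std::unordered_map<std::string, IntOrDouble> known_values{{"i0", 1024}};

  for (const auto& kv : getBankConflictInfo(gpulw.kernel(), lp, known_values)) {
    std::cout << "read conflict ways: " << kv.second.first
              << ", write conflict ways: " << kv.second.second << std::endl;
  }
}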