@@ -189,6 +189,119 @@ void CudaPrinter::visit(const IfThenElse* v) {
189
189
v->false_value ().accept (this );
190
190
}
191
191
192
+ class PrioritizeLoad : public IRMutator {
193
+ public:
194
+ virtual Expr mutate (const Load* v) {
195
+ MemLoadList& load_list = load_stack_.back ();
196
+ Var load_new_var{" v" , v->dtype ()};
197
+ Expr new_value = IRMutator::mutate (v);
198
+ load_list.push_back (std::make_pair (load_new_var.node (), new_value));
199
+ return load_new_var;
200
+ }
201
+
202
+ // TODO: merge this with the IRMutator::mutate version.
203
+ virtual Stmt mutate (const For* v) {
204
+ Var var = v->var ();
205
+ Expr start = v->start ();
206
+ Expr stop = v->stop ();
207
+ Stmt body = v->body ();
208
+ LoopOptions loop_options = v->loop_options ();
209
+ Expr var_new_expr = var.accept_mutator (this );
210
+ Var var_new = Var (var_new_expr.AsNode <Variable>());
211
+ Expr start_new = start.accept_mutator (this );
212
+ Expr stop_new = stop.accept_mutator (this );
213
+ PushList ();
214
+ Stmt body_new = body.accept_mutator (this );
215
+ Stmt body_with_loads = AddMemLoadsFromList (body_new);
216
+ PopList ();
217
+ if (same_node (var, var_new) && same_node (start, start_new) &&
218
+ same_node (stop, stop_new) && same_node (body, body_with_loads)) {
219
+ return Stmt (v);
220
+ }
221
+ return For::make (
222
+ var_new, start_new, stop_new, body_with_loads, loop_options);
223
+ }
224
+
225
+ virtual Stmt mutate (const LetStmt* v) {
226
+ Var var = v->var ();
227
+ Expr value = v->value ();
228
+ Stmt body = v->body ();
229
+ Expr var_new_expr = var.accept_mutator (this );
230
+ Variable* var_new_ptr = var_new_expr.AsNode <Variable>();
231
+ if (var_new_ptr == nullptr ) {
232
+ throw std::runtime_error (" LetStmt var must be variable" );
233
+ }
234
+ Var var_new{var_new_ptr};
235
+ Expr value_new = value.accept_mutator (this );
236
+ PushList ();
237
+ Stmt body_new = body.accept_mutator (this );
238
+ Stmt body_with_loads = AddMemLoadsFromList (body_new);
239
+ PopList ();
240
+ if (same_node (var, var_new) && same_node (value, value_new) &&
241
+ same_node (body, body_with_loads)) {
242
+ return Stmt (v);
243
+ }
244
+ return LetStmt::make (var_new, value_new, body_with_loads);
245
+ }
246
+
247
+ virtual Stmt mutate (const Cond* v) {
248
+ Expr cond_old = v->condition ();
249
+ Stmt true_old = v->true_stmt ();
250
+ Stmt false_old = v->false_stmt ();
251
+
252
+ Expr cond_new = cond_old.accept_mutator (this );
253
+ PushList ();
254
+ Stmt true_new = true_old.accept_mutator (this );
255
+ Stmt true_with_loads = AddMemLoadsFromList (true_new);
256
+ PopList ();
257
+ PushList ();
258
+ Stmt false_new = false_old.accept_mutator (this );
259
+ Stmt false_with_loads = AddMemLoadsFromList (false_new);
260
+ PopList ();
261
+
262
+ if (same_node (cond_old, cond_new) && same_node (true_old, true_with_loads) &&
263
+ same_node (false_old, false_with_loads)) {
264
+ return Stmt (v);
265
+ }
266
+ return Cond::make (cond_new, true_with_loads, false_with_loads);
267
+ }
268
+
269
+ Stmt Process (const Stmt& stmt) {
270
+ this ->PushList ();
271
+ Stmt stmt_v = stmt;
272
+ Stmt stmt_new = stmt_v.accept_mutator (this );
273
+ Stmt stmt_with_loads = AddMemLoadsFromList (stmt_new);
274
+ this ->PopList ();
275
+ return stmt_with_loads;
276
+ }
277
+
278
+ private:
279
+ using MemLoadEntry = std::pair<const Variable*, Expr>;
280
+ using MemLoadList = std::vector<MemLoadEntry>;
281
+ using MemoryLoadStack = std::vector<MemLoadList>;
282
+
283
+ void PushList () {
284
+ load_stack_.push_back (MemLoadList ());
285
+ }
286
+
287
+ void PopList () {
288
+ load_stack_.pop_back ();
289
+ }
290
+
291
+ Stmt AddMemLoadsFromList (const Stmt& stmt) {
292
+ MemLoadList& load_list = load_stack_.back ();
293
+ Stmt stmt_v = stmt;
294
+ for (int i = load_list.size () - 1 ; i >= 0 ; i--) {
295
+ const MemLoadEntry& entry = load_list[i];
296
+ Variable* var_ptr = const_cast <Variable*>(entry.first );
297
+ stmt_v = LetStmt::make (Var (var_ptr), entry.second , stmt_v);
298
+ }
299
+ return stmt_v;
300
+ }
301
+
302
+ MemoryLoadStack load_stack_;
303
+ };
304
+
192
305
void CudaCodeGen::Initialize () {
193
306
printer_.reset (new CudaPrinter (&oss_));
194
307
// TODO: handle multiple kernels.
@@ -209,7 +322,10 @@ void CudaCodeGen::Initialize() {
209
322
os () << " ) {" ;
210
323
211
324
os () << std::endl;
212
- stmt ().accept (printer_.get ());
325
+ Stmt stmt_v = stmt ();
326
+ PrioritizeLoad prioritize_load;
327
+ stmt_v = prioritize_load.Process (stmt_v);
328
+ stmt_v.accept (printer_.get ());
213
329
os () << std::endl;
214
330
os () << " }" ;
215
331
0 commit comments