
Commit afa23a6

zheng-xq authored and Mikhail Zolotukhin committed
Add PrioritizeLoad to CudaCodeGen. (pytorch#179)
1 parent 8bbd8cb commit afa23a6
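
In brief: the new PrioritizeLoad pass is an IRMutator that hoists every Load out of the expression that uses it. Each load is replaced by a fresh variable, the (variable, load) pairs are collected per scope on a stack of lists, and AddMemLoadsFromList re-attaches them as LetStmt bindings in front of the enclosing For/LetStmt/Cond body (or the whole kernel, via Process). CudaCodeGen::Initialize now runs this pass over the kernel statement before handing it to CudaPrinter. A rough sketch of the intended effect on the emitted CUDA source follows; the kernel signatures, buffer names A/B/C, the index boilerplate, and the v/v_1 naming are all illustrative assumptions, since the printed form of a LetStmt is not part of this diff.

// Illustrative only: roughly what an element-wise kernel body looks like
// without the pass, with loads inlined in the expression.
__global__ void before_pass(const float* A, const float* B, float* C) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // hypothetical indexing
  C[i] = A[i] + B[i];
}

// ...and with PrioritizeLoad applied: each load is bound to a scalar
// before the expression that consumes it.
__global__ void after_pass(const float* A, const float* B, float* C) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // hypothetical indexing
  float v = A[i];
  float v_1 = B[i];
  C[i] = v + v_1;
}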

1 file changed: +117 -1 lines changed

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

Lines changed: 117 additions & 1 deletion
@@ -189,6 +189,119 @@ void CudaPrinter::visit(const IfThenElse* v) {
   v->false_value().accept(this);
 }
 
+class PrioritizeLoad : public IRMutator {
+ public:
+  virtual Expr mutate(const Load* v) {
+    MemLoadList& load_list = load_stack_.back();
+    Var load_new_var{"v", v->dtype()};
+    Expr new_value = IRMutator::mutate(v);
+    load_list.push_back(std::make_pair(load_new_var.node(), new_value));
+    return load_new_var;
+  }
+
+  // TODO: merge this with the IRMutator::mutate version.
+  virtual Stmt mutate(const For* v) {
+    Var var = v->var();
+    Expr start = v->start();
+    Expr stop = v->stop();
+    Stmt body = v->body();
+    LoopOptions loop_options = v->loop_options();
+    Expr var_new_expr = var.accept_mutator(this);
+    Var var_new = Var(var_new_expr.AsNode<Variable>());
+    Expr start_new = start.accept_mutator(this);
+    Expr stop_new = stop.accept_mutator(this);
+    PushList();
+    Stmt body_new = body.accept_mutator(this);
+    Stmt body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (same_node(var, var_new) && same_node(start, start_new) &&
+        same_node(stop, stop_new) && same_node(body, body_with_loads)) {
+      return Stmt(v);
+    }
+    return For::make(
+        var_new, start_new, stop_new, body_with_loads, loop_options);
+  }
+
+  virtual Stmt mutate(const LetStmt* v) {
+    Var var = v->var();
+    Expr value = v->value();
+    Stmt body = v->body();
+    Expr var_new_expr = var.accept_mutator(this);
+    Variable* var_new_ptr = var_new_expr.AsNode<Variable>();
+    if (var_new_ptr == nullptr) {
+      throw std::runtime_error("LetStmt var must be variable");
+    }
+    Var var_new{var_new_ptr};
+    Expr value_new = value.accept_mutator(this);
+    PushList();
+    Stmt body_new = body.accept_mutator(this);
+    Stmt body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (same_node(var, var_new) && same_node(value, value_new) &&
+        same_node(body, body_with_loads)) {
+      return Stmt(v);
+    }
+    return LetStmt::make(var_new, value_new, body_with_loads);
+  }
+
+  virtual Stmt mutate(const Cond* v) {
+    Expr cond_old = v->condition();
+    Stmt true_old = v->true_stmt();
+    Stmt false_old = v->false_stmt();
+
+    Expr cond_new = cond_old.accept_mutator(this);
+    PushList();
+    Stmt true_new = true_old.accept_mutator(this);
+    Stmt true_with_loads = AddMemLoadsFromList(true_new);
+    PopList();
+    PushList();
+    Stmt false_new = false_old.accept_mutator(this);
+    Stmt false_with_loads = AddMemLoadsFromList(false_new);
+    PopList();
+
+    if (same_node(cond_old, cond_new) && same_node(true_old, true_with_loads) &&
+        same_node(false_old, false_with_loads)) {
+      return Stmt(v);
+    }
+    return Cond::make(cond_new, true_with_loads, false_with_loads);
+  }
+
+  Stmt Process(const Stmt& stmt) {
+    this->PushList();
+    Stmt stmt_v = stmt;
+    Stmt stmt_new = stmt_v.accept_mutator(this);
+    Stmt stmt_with_loads = AddMemLoadsFromList(stmt_new);
+    this->PopList();
+    return stmt_with_loads;
+  }
+
+ private:
+  using MemLoadEntry = std::pair<const Variable*, Expr>;
+  using MemLoadList = std::vector<MemLoadEntry>;
+  using MemoryLoadStack = std::vector<MemLoadList>;
+
+  void PushList() {
+    load_stack_.push_back(MemLoadList());
+  }
+
+  void PopList() {
+    load_stack_.pop_back();
+  }
+
+  Stmt AddMemLoadsFromList(const Stmt& stmt) {
+    MemLoadList& load_list = load_stack_.back();
+    Stmt stmt_v = stmt;
+    for (int i = load_list.size() - 1; i >= 0; i--) {
+      const MemLoadEntry& entry = load_list[i];
+      Variable* var_ptr = const_cast<Variable*>(entry.first);
+      stmt_v = LetStmt::make(Var(var_ptr), entry.second, stmt_v);
+    }
+    return stmt_v;
+  }
+
+  MemoryLoadStack load_stack_;
+};
+
 void CudaCodeGen::Initialize() {
   printer_.reset(new CudaPrinter(&oss_));
   // TODO: handle multiple kernels.
@@ -209,7 +322,10 @@ void CudaCodeGen::Initialize() {
   os() << ") {";
 
   os() << std::endl;
-  stmt().accept(printer_.get());
+  Stmt stmt_v = stmt();
+  PrioritizeLoad prioritize_load;
+  stmt_v = prioritize_load.Process(stmt_v);
+  stmt_v.accept(printer_.get());
   os() << std::endl;
   os() << "}";
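
A note on binding order in AddMemLoadsFromList: the collected list is walked back to front, so the last load collected is wrapped innermost and the first one ends up as the outermost LetStmt, meaning the declarations are emitted in the order the loads were encountered. Conceptually, for loads l0 and l1 collected in that order around a body S (schematic notation, not actual IR syntax):

// let v0 = l0 in
//   let v1 = l1 in
//     S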
