Skip to content

Commit cede340

Browse files
authored
Misc cleanup (#109)
1 parent dca5dba commit cede340

File tree

2 files changed

+82
-73
lines changed

2 files changed

+82
-73
lines changed

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 75 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ T* ptr(T* obj) {
3636
* }
3737
*
3838
* And therefore dispatch should never call:
39-
* ptr(mutator)->handle(static_cast<Statement*>(this));
39+
* ptr(mutator)->handle(this->as<Statement>());
4040
*/
4141

4242
template <typename T>
@@ -45,35 +45,35 @@ void Val::dispatch(T handler, Val* val) {
4545
case ValType::Scalar:
4646
switch (*(val->getDataType())) {
4747
case DataType::Bool:
48-
ptr(handler)->handle(static_cast<Bool*>(val));
48+
ptr(handler)->handle(val->as<Bool>());
4949
return;
5050
case DataType::Float:
51-
ptr(handler)->handle(static_cast<Float*>(val));
51+
ptr(handler)->handle(val->as<Float>());
5252
return;
5353
case DataType::Half:
54-
ptr(handler)->handle(static_cast<Half*>(val));
54+
ptr(handler)->handle(val->as<Half>());
5555
return;
5656
case DataType::Int:
57-
ptr(handler)->handle(static_cast<Int*>(val));
57+
ptr(handler)->handle(val->as<Int>());
5858
return;
5959
default:
6060
break;
6161
}
6262
break;
6363
case ValType::IterDomain:
64-
ptr(handler)->handle(static_cast<IterDomain*>(val));
64+
ptr(handler)->handle(val->as<IterDomain>());
6565
return;
6666
case ValType::TensorDomain:
67-
ptr(handler)->handle(static_cast<TensorDomain*>(val));
67+
ptr(handler)->handle(val->as<TensorDomain>());
6868
return;
6969
case ValType::TensorView:
70-
ptr(handler)->handle(static_cast<TensorView*>(val));
70+
ptr(handler)->handle(val->as<TensorView>());
7171
return;
7272
case ValType::TensorIndex:
73-
ptr(handler)->handle(static_cast<TensorIndex*>(val));
73+
ptr(handler)->handle(val->as<TensorIndex>());
7474
return;
7575
case ValType::NamedScalar:
76-
ptr(handler)->handle(static_cast<NamedScalar*>(val));
76+
ptr(handler)->handle(val->as<NamedScalar>());
7777
return;
7878
default:
7979
break;
@@ -85,34 +85,34 @@ template <typename T>
8585
void Expr::dispatch(T handler, Expr* expr) {
8686
switch (*(expr->getExprType())) {
8787
case ExprType::Split:
88-
ptr(handler)->handle(static_cast<Split*>(expr));
88+
ptr(handler)->handle(expr->as<Split>());
8989
return;
9090
case ExprType::Merge:
91-
ptr(handler)->handle(static_cast<Merge*>(expr));
91+
ptr(handler)->handle(expr->as<Merge>());
9292
return;
9393
case ExprType::UnaryOp:
94-
ptr(handler)->handle(static_cast<UnaryOp*>(expr));
94+
ptr(handler)->handle(expr->as<UnaryOp>());
9595
return;
9696
case ExprType::BinaryOp:
97-
ptr(handler)->handle(static_cast<BinaryOp*>(expr));
97+
ptr(handler)->handle(expr->as<BinaryOp>());
9898
return;
9999
case ExprType::TernaryOp:
100-
ptr(handler)->handle(static_cast<TernaryOp*>(expr));
100+
ptr(handler)->handle(expr->as<TernaryOp>());
101101
return;
102102
case ExprType::ReductionOp:
103-
ptr(handler)->handle(static_cast<ReductionOp*>(expr));
103+
ptr(handler)->handle(expr->as<ReductionOp>());
104104
return;
105105
case ExprType::BroadcastOp:
106-
ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
106+
ptr(handler)->handle(expr->as<BroadcastOp>());
107107
return;
108108
case ExprType::ForLoop:
109-
ptr(handler)->handle(static_cast<ForLoop*>(expr));
109+
ptr(handler)->handle(expr->as<ForLoop>());
110110
return;
111111
case ExprType::IfThenElse:
112-
ptr(handler)->handle(static_cast<IfThenElse*>(expr));
112+
ptr(handler)->handle(expr->as<IfThenElse>());
113113
return;
114114
case ExprType::Allocate:
115-
ptr(handler)->handle(static_cast<Allocate*>(expr));
115+
ptr(handler)->handle(expr->as<Allocate>());
116116
return;
117117
default:
118118
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -122,9 +122,9 @@ void Expr::dispatch(T handler, Expr* expr) {
122122
template <typename T>
123123
void Statement::dispatch(T handler, Statement* stmt) {
124124
if (stmt->isVal()) {
125-
ptr(handler)->handle(static_cast<Val*>(stmt));
125+
ptr(handler)->handle(stmt->as<Val>());
126126
} else if (stmt->isExpr()) {
127-
ptr(handler)->handle(static_cast<Expr*>(stmt));
127+
ptr(handler)->handle(stmt->as<Expr>());
128128
} else
129129
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
130130
}
@@ -135,35 +135,35 @@ void Val::constDispatch(T handler, const Val* val) {
135135
case ValType::Scalar:
136136
switch (*(val->getDataType())) {
137137
case DataType::Bool:
138-
ptr(handler)->handle(static_cast<const Bool*>(val));
138+
ptr(handler)->handle(val->as<Bool>());
139139
return;
140140
case DataType::Float:
141-
ptr(handler)->handle(static_cast<const Float*>(val));
141+
ptr(handler)->handle(val->as<Float>());
142142
return;
143143
case DataType::Half:
144-
ptr(handler)->handle(static_cast<const Half*>(val));
144+
ptr(handler)->handle(val->as<Half>());
145145
return;
146146
case DataType::Int:
147-
ptr(handler)->handle(static_cast<const Int*>(val));
147+
ptr(handler)->handle(val->as<Int>());
148148
return;
149149
default:
150150
break;
151151
}
152152
break;
153153
case ValType::IterDomain:
154-
ptr(handler)->handle(static_cast<const IterDomain*>(val));
154+
ptr(handler)->handle(val->as<IterDomain>());
155155
return;
156156
case ValType::TensorDomain:
157-
ptr(handler)->handle(static_cast<const TensorDomain*>(val));
157+
ptr(handler)->handle(val->as<TensorDomain>());
158158
return;
159159
case ValType::TensorView:
160-
ptr(handler)->handle(static_cast<const TensorView*>(val));
160+
ptr(handler)->handle(val->as<TensorView>());
161161
return;
162162
case ValType::TensorIndex:
163-
ptr(handler)->handle(static_cast<const TensorIndex*>(val));
163+
ptr(handler)->handle(val->as<TensorIndex>());
164164
return;
165165
case ValType::NamedScalar:
166-
ptr(handler)->handle(static_cast<const NamedScalar*>(val));
166+
ptr(handler)->handle(val->as<NamedScalar>());
167167
return;
168168
default:
169169
break;
@@ -175,34 +175,34 @@ template <typename T>
175175
void Expr::constDispatch(T handler, const Expr* expr) {
176176
switch (*(expr->getExprType())) {
177177
case ExprType::Split:
178-
ptr(handler)->handle(static_cast<const Split*>(expr));
178+
ptr(handler)->handle(expr->as<Split>());
179179
return;
180180
case ExprType::Merge:
181-
ptr(handler)->handle(static_cast<const Merge*>(expr));
181+
ptr(handler)->handle(expr->as<Merge>());
182182
return;
183183
case ExprType::UnaryOp:
184-
ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
184+
ptr(handler)->handle(expr->as<UnaryOp>());
185185
return;
186186
case ExprType::BinaryOp:
187-
ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
187+
ptr(handler)->handle(expr->as<BinaryOp>());
188188
return;
189189
case ExprType::TernaryOp:
190-
ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
190+
ptr(handler)->handle(expr->as<TernaryOp>());
191191
return;
192192
case ExprType::ReductionOp:
193-
ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
193+
ptr(handler)->handle(expr->as<ReductionOp>());
194194
return;
195195
case ExprType::BroadcastOp:
196-
ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
196+
ptr(handler)->handle(expr->as<BroadcastOp>());
197197
return;
198198
case ExprType::ForLoop:
199-
ptr(handler)->handle(static_cast<const ForLoop*>(expr));
199+
ptr(handler)->handle(expr->as<ForLoop>());
200200
return;
201201
case ExprType::IfThenElse:
202-
ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
202+
ptr(handler)->handle(expr->as<IfThenElse>());
203203
return;
204204
case ExprType::Allocate:
205-
ptr(handler)->handle(static_cast<const Allocate*>(expr));
205+
ptr(handler)->handle(expr->as<Allocate>());
206206
return;
207207
default:
208208
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -212,9 +212,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
212212
template <typename T>
213213
void Statement::constDispatch(T handler, const Statement* stmt) {
214214
if (stmt->isVal()) {
215-
ptr(handler)->handle(static_cast<const Val*>(stmt));
215+
ptr(handler)->handle(stmt->as<Val>());
216216
} else if (stmt->isExpr()) {
217-
ptr(handler)->handle(static_cast<const Expr*>(stmt));
217+
ptr(handler)->handle(stmt->as<Expr>());
218218
} else
219219
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
220220
}
@@ -228,35 +228,35 @@ void Statement::constDispatch(T handler, const Statement* stmt) {
228228
* implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
229229
* }
230230
* And therefore dispatch should never call:
231-
* ptr(mutator)->mutate(static_cast<Statement*>(this));
231+
* ptr(mutator)->mutate(this->as<Statement>());
232232
*/
233233
template <typename T>
234234
Statement* Val::mutatorDispatch(T mutator, Val* val) {
235235
switch (*(val->getValType())) {
236236
case ValType::Scalar:
237237
switch (*(val->getDataType())) {
238238
case DataType::Bool:
239-
return ptr(mutator)->mutate(static_cast<Bool*>(val));
239+
return ptr(mutator)->mutate(val->as<Bool>());
240240
case DataType::Float:
241-
return ptr(mutator)->mutate(static_cast<Float*>(val));
241+
return ptr(mutator)->mutate(val->as<Float>());
242242
case DataType::Half:
243-
return ptr(mutator)->mutate(static_cast<Half*>(val));
243+
return ptr(mutator)->mutate(val->as<Half>());
244244
case DataType::Int:
245-
return ptr(mutator)->mutate(static_cast<Int*>(val));
245+
return ptr(mutator)->mutate(val->as<Int>());
246246
default:
247247
break;
248248
}
249249
break;
250250
case ValType::IterDomain:
251-
return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
251+
return ptr(mutator)->mutate(val->as<IterDomain>());
252252
case ValType::TensorDomain:
253-
return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
253+
return ptr(mutator)->mutate(val->as<TensorDomain>());
254254
case ValType::TensorView:
255-
return ptr(mutator)->mutate(static_cast<TensorView*>(val));
255+
return ptr(mutator)->mutate(val->as<TensorView>());
256256
case ValType::TensorIndex:
257-
return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
257+
return ptr(mutator)->mutate(val->as<TensorIndex>());
258258
case ValType::NamedScalar:
259-
return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
259+
return ptr(mutator)->mutate(val->as<NamedScalar>());
260260
default:
261261
break;
262262
}
@@ -267,25 +267,25 @@ template <typename T>
267267
Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
268268
switch (*(expr->getExprType())) {
269269
case ExprType::Split:
270-
return ptr(mutator)->mutate(static_cast<Split*>(expr));
270+
return ptr(mutator)->mutate(expr->as<Split>());
271271
case ExprType::Merge:
272-
return ptr(mutator)->mutate(static_cast<Merge*>(expr));
272+
return ptr(mutator)->mutate(expr->as<Merge>());
273273
case ExprType::UnaryOp:
274-
return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
274+
return ptr(mutator)->mutate(expr->as<UnaryOp>());
275275
case ExprType::BinaryOp:
276-
return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
276+
return ptr(mutator)->mutate(expr->as<BinaryOp>());
277277
case ExprType::TernaryOp:
278-
return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
278+
return ptr(mutator)->mutate(expr->as<TernaryOp>());
279279
case ExprType::ReductionOp:
280-
return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
280+
return ptr(mutator)->mutate(expr->as<ReductionOp>());
281281
case ExprType::BroadcastOp:
282-
return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
282+
return ptr(mutator)->mutate(expr->as<BroadcastOp>());
283283
case ExprType::ForLoop:
284-
return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
284+
return ptr(mutator)->mutate(expr->as<ForLoop>());
285285
case ExprType::IfThenElse:
286-
return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
286+
return ptr(mutator)->mutate(expr->as<IfThenElse>());
287287
case ExprType::Allocate:
288-
return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
288+
return ptr(mutator)->mutate(expr->as<Allocate>());
289289
default:
290290
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
291291
}
@@ -294,10 +294,10 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
294294
template <typename T>
295295
Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
296296
if (stmt->isVal()) {
297-
return ptr(mutator)->mutate(static_cast<Val*>(stmt));
297+
return ptr(mutator)->mutate(stmt->as<Val>());
298298
}
299299
if (stmt->isExpr()) {
300-
return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
300+
return ptr(mutator)->mutate(stmt->as<Expr>());
301301
}
302302
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
303303
}
@@ -352,39 +352,47 @@ template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
352352
void OptOutDispatch::handle(Statement* s) {
353353
Statement::dispatch(this, s);
354354
}
355+
355356
void OptOutDispatch::handle(Expr* e) {
356357
Expr::dispatch(this, e);
357358
}
359+
358360
void OptOutDispatch::handle(Val* v) {
359361
Val::dispatch(this, v);
360362
}
361363

362364
void OptInDispatch::handle(Statement* s) {
363365
Statement::dispatch(this, s);
364366
}
367+
365368
void OptInDispatch::handle(Expr* e) {
366369
Expr::dispatch(this, e);
367370
}
371+
368372
void OptInDispatch::handle(Val* v) {
369373
Val::dispatch(this, v);
370374
}
371375

372376
void OptOutConstDispatch::handle(const Statement* s) {
373377
Statement::constDispatch(this, s);
374378
}
379+
375380
void OptOutConstDispatch::handle(const Expr* e) {
376381
Expr::constDispatch(this, e);
377382
}
383+
378384
void OptOutConstDispatch::handle(const Val* v) {
379385
Val::constDispatch(this, v);
380386
}
381387

382388
void OptInConstDispatch::handle(const Statement* s) {
383389
Statement::constDispatch(this, s);
384390
}
391+
385392
void OptInConstDispatch::handle(const Expr* e) {
386393
Expr::constDispatch(this, e);
387394
}
395+
388396
void OptInConstDispatch::handle(const Val* v) {
389397
Val::constDispatch(this, v);
390398
}
@@ -407,9 +415,11 @@ Statement* OptInMutator::mutate(Val* v) {
407415
Statement* OptOutMutator::mutate(Statement* s) {
408416
return Statement::mutatorDispatch(this, s);
409417
}
418+
410419
Statement* OptOutMutator::mutate(Expr* e) {
411420
return Expr::mutatorDispatch(this, e);
412421
}
422+
413423
Statement* OptOutMutator::mutate(Val* v) {
414424
// If value is already mutated, return the mutation
415425
if (mutations.find(v) != mutations.end())

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,11 @@ struct TORCH_CUDA_API OptOutMutator {
331331
virtual Statement* mutate(Expr* e);
332332
virtual Statement* mutate(Val* v);
333333

334-
/*
335-
* We always want to dispatch through a Val, so we can capture and dispatch
336-
* correctly members of nodes like Split->TensorDomain If we don't call the
337-
* below function or manually cast to use mutate(Val* v) we can't intercept
338-
* and mutate by capturing mutate(Val* v), which is what we do when we want to
339-
* replace all instances of a value.
340-
*/
334+
// We always want to dispatch through a Val, so we can capture and dispatch
335+
// correctly members of nodes like Split->TensorDomain If we don't call the
336+
// below function or manually cast to use mutate(Val* v) we can't intercept
337+
// and mutate by capturing mutate(Val* v), which is what we do when we want to
338+
// replace all instances of a value.
341339
Statement* mutateAsVal(Val* v) {
342340
return mutate(v);
343341
}
@@ -352,7 +350,8 @@ struct TORCH_CUDA_API OptOutMutator {
352350

353351
std::unordered_map<Val*, Val*> mutations;
354352

355-
//****Functions below defined in mutator.cpp*****///
353+
//****Functions below defined in mutator.cpp*****
354+
356355
// Vals
357356
virtual Statement* mutate(IterDomain*);
358357
virtual Statement* mutate(TensorDomain*);

0 commit comments

Comments
 (0)