@@ -36,7 +36,7 @@ T* ptr(T* obj) {
36
36
* }
37
37
*
38
38
* And therefore dispatch should never call:
39
- * ptr(mutator)->handle(static_cast <Statement*>(this ));
39
+ * ptr(mutator)->handle(this->as <Statement>( ));
40
40
*/
41
41
42
42
template <typename T>
@@ -45,35 +45,35 @@ void Val::dispatch(T handler, Val* val) {
45
45
case ValType::Scalar:
46
46
switch (*(val->getDataType ())) {
47
47
case DataType::Bool:
48
- ptr (handler)->handle (static_cast <Bool*>(val ));
48
+ ptr (handler)->handle (val-> as <Bool>( ));
49
49
return ;
50
50
case DataType::Float:
51
- ptr (handler)->handle (static_cast <Float*>(val ));
51
+ ptr (handler)->handle (val-> as <Float>( ));
52
52
return ;
53
53
case DataType::Half:
54
- ptr (handler)->handle (static_cast <Half*>(val ));
54
+ ptr (handler)->handle (val-> as <Half>( ));
55
55
return ;
56
56
case DataType::Int:
57
- ptr (handler)->handle (static_cast <Int*>(val ));
57
+ ptr (handler)->handle (val-> as <Int>( ));
58
58
return ;
59
59
default :
60
60
break ;
61
61
}
62
62
break ;
63
63
case ValType::IterDomain:
64
- ptr (handler)->handle (static_cast <IterDomain*>(val ));
64
+ ptr (handler)->handle (val-> as <IterDomain>( ));
65
65
return ;
66
66
case ValType::TensorDomain:
67
- ptr (handler)->handle (static_cast <TensorDomain*>(val ));
67
+ ptr (handler)->handle (val-> as <TensorDomain>( ));
68
68
return ;
69
69
case ValType::TensorView:
70
- ptr (handler)->handle (static_cast <TensorView*>(val ));
70
+ ptr (handler)->handle (val-> as <TensorView>( ));
71
71
return ;
72
72
case ValType::TensorIndex:
73
- ptr (handler)->handle (static_cast <TensorIndex*>(val ));
73
+ ptr (handler)->handle (val-> as <TensorIndex>( ));
74
74
return ;
75
75
case ValType::NamedScalar:
76
- ptr (handler)->handle (static_cast <NamedScalar*>(val ));
76
+ ptr (handler)->handle (val-> as <NamedScalar>( ));
77
77
return ;
78
78
default :
79
79
break ;
@@ -85,34 +85,34 @@ template <typename T>
85
85
void Expr::dispatch (T handler, Expr* expr) {
86
86
switch (*(expr->getExprType ())) {
87
87
case ExprType::Split:
88
- ptr (handler)->handle (static_cast <Split*>(expr ));
88
+ ptr (handler)->handle (expr-> as <Split>( ));
89
89
return ;
90
90
case ExprType::Merge:
91
- ptr (handler)->handle (static_cast <Merge*>(expr ));
91
+ ptr (handler)->handle (expr-> as <Merge>( ));
92
92
return ;
93
93
case ExprType::UnaryOp:
94
- ptr (handler)->handle (static_cast <UnaryOp*>(expr ));
94
+ ptr (handler)->handle (expr-> as <UnaryOp>( ));
95
95
return ;
96
96
case ExprType::BinaryOp:
97
- ptr (handler)->handle (static_cast <BinaryOp*>(expr ));
97
+ ptr (handler)->handle (expr-> as <BinaryOp>( ));
98
98
return ;
99
99
case ExprType::TernaryOp:
100
- ptr (handler)->handle (static_cast <TernaryOp*>(expr ));
100
+ ptr (handler)->handle (expr-> as <TernaryOp>( ));
101
101
return ;
102
102
case ExprType::ReductionOp:
103
- ptr (handler)->handle (static_cast <ReductionOp*>(expr ));
103
+ ptr (handler)->handle (expr-> as <ReductionOp>( ));
104
104
return ;
105
105
case ExprType::BroadcastOp:
106
- ptr (handler)->handle (static_cast <BroadcastOp*>(expr ));
106
+ ptr (handler)->handle (expr-> as <BroadcastOp>( ));
107
107
return ;
108
108
case ExprType::ForLoop:
109
- ptr (handler)->handle (static_cast <ForLoop*>(expr ));
109
+ ptr (handler)->handle (expr-> as <ForLoop>( ));
110
110
return ;
111
111
case ExprType::IfThenElse:
112
- ptr (handler)->handle (static_cast <IfThenElse*>(expr ));
112
+ ptr (handler)->handle (expr-> as <IfThenElse>( ));
113
113
return ;
114
114
case ExprType::Allocate:
115
- ptr (handler)->handle (static_cast <Allocate*>(expr ));
115
+ ptr (handler)->handle (expr-> as <Allocate>( ));
116
116
return ;
117
117
default :
118
118
TORCH_INTERNAL_ASSERT (false , " Unknown exprtype in dispatch!" );
@@ -122,9 +122,9 @@ void Expr::dispatch(T handler, Expr* expr) {
122
122
template <typename T>
123
123
void Statement::dispatch (T handler, Statement* stmt) {
124
124
if (stmt->isVal ()) {
125
- ptr (handler)->handle (static_cast <Val*>(stmt ));
125
+ ptr (handler)->handle (stmt-> as <Val>( ));
126
126
} else if (stmt->isExpr ()) {
127
- ptr (handler)->handle (static_cast <Expr*>(stmt ));
127
+ ptr (handler)->handle (stmt-> as <Expr>( ));
128
128
} else
129
129
TORCH_INTERNAL_ASSERT (false , " Unknown stmttype in dispatch!" );
130
130
}
@@ -135,35 +135,35 @@ void Val::constDispatch(T handler, const Val* val) {
135
135
case ValType::Scalar:
136
136
switch (*(val->getDataType ())) {
137
137
case DataType::Bool:
138
- ptr (handler)->handle (static_cast < const Bool*>(val ));
138
+ ptr (handler)->handle (val-> as < Bool>( ));
139
139
return ;
140
140
case DataType::Float:
141
- ptr (handler)->handle (static_cast < const Float*>(val ));
141
+ ptr (handler)->handle (val-> as < Float>( ));
142
142
return ;
143
143
case DataType::Half:
144
- ptr (handler)->handle (static_cast < const Half*>(val ));
144
+ ptr (handler)->handle (val-> as < Half>( ));
145
145
return ;
146
146
case DataType::Int:
147
- ptr (handler)->handle (static_cast < const Int*>(val ));
147
+ ptr (handler)->handle (val-> as < Int>( ));
148
148
return ;
149
149
default :
150
150
break ;
151
151
}
152
152
break ;
153
153
case ValType::IterDomain:
154
- ptr (handler)->handle (static_cast < const IterDomain*>(val ));
154
+ ptr (handler)->handle (val-> as < IterDomain>( ));
155
155
return ;
156
156
case ValType::TensorDomain:
157
- ptr (handler)->handle (static_cast < const TensorDomain*>(val ));
157
+ ptr (handler)->handle (val-> as < TensorDomain>( ));
158
158
return ;
159
159
case ValType::TensorView:
160
- ptr (handler)->handle (static_cast < const TensorView*>(val ));
160
+ ptr (handler)->handle (val-> as < TensorView>( ));
161
161
return ;
162
162
case ValType::TensorIndex:
163
- ptr (handler)->handle (static_cast < const TensorIndex*>(val ));
163
+ ptr (handler)->handle (val-> as < TensorIndex>( ));
164
164
return ;
165
165
case ValType::NamedScalar:
166
- ptr (handler)->handle (static_cast < const NamedScalar*>(val ));
166
+ ptr (handler)->handle (val-> as < NamedScalar>( ));
167
167
return ;
168
168
default :
169
169
break ;
@@ -175,34 +175,34 @@ template <typename T>
175
175
void Expr::constDispatch (T handler, const Expr* expr) {
176
176
switch (*(expr->getExprType ())) {
177
177
case ExprType::Split:
178
- ptr (handler)->handle (static_cast < const Split*>(expr ));
178
+ ptr (handler)->handle (expr-> as < Split>( ));
179
179
return ;
180
180
case ExprType::Merge:
181
- ptr (handler)->handle (static_cast < const Merge*>(expr ));
181
+ ptr (handler)->handle (expr-> as < Merge>( ));
182
182
return ;
183
183
case ExprType::UnaryOp:
184
- ptr (handler)->handle (static_cast < const UnaryOp*>(expr ));
184
+ ptr (handler)->handle (expr-> as < UnaryOp>( ));
185
185
return ;
186
186
case ExprType::BinaryOp:
187
- ptr (handler)->handle (static_cast < const BinaryOp*>(expr ));
187
+ ptr (handler)->handle (expr-> as < BinaryOp>( ));
188
188
return ;
189
189
case ExprType::TernaryOp:
190
- ptr (handler)->handle (static_cast < const TernaryOp*>(expr ));
190
+ ptr (handler)->handle (expr-> as < TernaryOp>( ));
191
191
return ;
192
192
case ExprType::ReductionOp:
193
- ptr (handler)->handle (static_cast < const ReductionOp*>(expr ));
193
+ ptr (handler)->handle (expr-> as < ReductionOp>( ));
194
194
return ;
195
195
case ExprType::BroadcastOp:
196
- ptr (handler)->handle (static_cast < const BroadcastOp*>(expr ));
196
+ ptr (handler)->handle (expr-> as < BroadcastOp>( ));
197
197
return ;
198
198
case ExprType::ForLoop:
199
- ptr (handler)->handle (static_cast < const ForLoop*>(expr ));
199
+ ptr (handler)->handle (expr-> as < ForLoop>( ));
200
200
return ;
201
201
case ExprType::IfThenElse:
202
- ptr (handler)->handle (static_cast < const IfThenElse*>(expr ));
202
+ ptr (handler)->handle (expr-> as < IfThenElse>( ));
203
203
return ;
204
204
case ExprType::Allocate:
205
- ptr (handler)->handle (static_cast < const Allocate*>(expr ));
205
+ ptr (handler)->handle (expr-> as < Allocate>( ));
206
206
return ;
207
207
default :
208
208
TORCH_INTERNAL_ASSERT (false , " Unknown exprtype in dispatch!" );
@@ -212,9 +212,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
212
212
template <typename T>
213
213
void Statement::constDispatch (T handler, const Statement* stmt) {
214
214
if (stmt->isVal ()) {
215
- ptr (handler)->handle (static_cast < const Val*>(stmt ));
215
+ ptr (handler)->handle (stmt-> as < Val>( ));
216
216
} else if (stmt->isExpr ()) {
217
- ptr (handler)->handle (static_cast < const Expr*>(stmt ));
217
+ ptr (handler)->handle (stmt-> as < Expr>( ));
218
218
} else
219
219
TORCH_INTERNAL_ASSERT (false , " Unknown stmttype in dispatch!" );
220
220
}
@@ -228,35 +228,35 @@ void Statement::constDispatch(T handler, const Statement* stmt) {
228
228
* implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
229
229
* }
230
230
* And therefore dispatch should never call:
231
- * ptr(mutator)->mutate(static_cast <Statement*>(this ));
231
+ * ptr(mutator)->mutate(this->as <Statement>( ));
232
232
*/
233
233
template <typename T>
234
234
Statement* Val::mutatorDispatch (T mutator, Val* val) {
235
235
switch (*(val->getValType ())) {
236
236
case ValType::Scalar:
237
237
switch (*(val->getDataType ())) {
238
238
case DataType::Bool:
239
- return ptr (mutator)->mutate (static_cast <Bool*>(val ));
239
+ return ptr (mutator)->mutate (val-> as <Bool>( ));
240
240
case DataType::Float:
241
- return ptr (mutator)->mutate (static_cast <Float*>(val ));
241
+ return ptr (mutator)->mutate (val-> as <Float>( ));
242
242
case DataType::Half:
243
- return ptr (mutator)->mutate (static_cast <Half*>(val ));
243
+ return ptr (mutator)->mutate (val-> as <Half>( ));
244
244
case DataType::Int:
245
- return ptr (mutator)->mutate (static_cast <Int*>(val ));
245
+ return ptr (mutator)->mutate (val-> as <Int>( ));
246
246
default :
247
247
break ;
248
248
}
249
249
break ;
250
250
case ValType::IterDomain:
251
- return ptr (mutator)->mutate (static_cast <IterDomain*>(val ));
251
+ return ptr (mutator)->mutate (val-> as <IterDomain>( ));
252
252
case ValType::TensorDomain:
253
- return ptr (mutator)->mutate (static_cast <TensorDomain*>(val ));
253
+ return ptr (mutator)->mutate (val-> as <TensorDomain>( ));
254
254
case ValType::TensorView:
255
- return ptr (mutator)->mutate (static_cast <TensorView*>(val ));
255
+ return ptr (mutator)->mutate (val-> as <TensorView>( ));
256
256
case ValType::TensorIndex:
257
- return ptr (mutator)->mutate (static_cast <TensorIndex*>(val ));
257
+ return ptr (mutator)->mutate (val-> as <TensorIndex>( ));
258
258
case ValType::NamedScalar:
259
- return ptr (mutator)->mutate (static_cast <NamedScalar*>(val ));
259
+ return ptr (mutator)->mutate (val-> as <NamedScalar>( ));
260
260
default :
261
261
break ;
262
262
}
@@ -267,25 +267,25 @@ template <typename T>
267
267
Statement* Expr::mutatorDispatch (T mutator, Expr* expr) {
268
268
switch (*(expr->getExprType ())) {
269
269
case ExprType::Split:
270
- return ptr (mutator)->mutate (static_cast <Split*>(expr ));
270
+ return ptr (mutator)->mutate (expr-> as <Split>( ));
271
271
case ExprType::Merge:
272
- return ptr (mutator)->mutate (static_cast <Merge*>(expr ));
272
+ return ptr (mutator)->mutate (expr-> as <Merge>( ));
273
273
case ExprType::UnaryOp:
274
- return ptr (mutator)->mutate (static_cast <UnaryOp*>(expr ));
274
+ return ptr (mutator)->mutate (expr-> as <UnaryOp>( ));
275
275
case ExprType::BinaryOp:
276
- return ptr (mutator)->mutate (static_cast <BinaryOp*>(expr ));
276
+ return ptr (mutator)->mutate (expr-> as <BinaryOp>( ));
277
277
case ExprType::TernaryOp:
278
- return ptr (mutator)->mutate (static_cast <TernaryOp*>(expr ));
278
+ return ptr (mutator)->mutate (expr-> as <TernaryOp>( ));
279
279
case ExprType::ReductionOp:
280
- return ptr (mutator)->mutate (static_cast <ReductionOp*>(expr ));
280
+ return ptr (mutator)->mutate (expr-> as <ReductionOp>( ));
281
281
case ExprType::BroadcastOp:
282
- return ptr (mutator)->mutate (static_cast <BroadcastOp*>(expr ));
282
+ return ptr (mutator)->mutate (expr-> as <BroadcastOp>( ));
283
283
case ExprType::ForLoop:
284
- return ptr (mutator)->mutate (static_cast <ForLoop*>(expr ));
284
+ return ptr (mutator)->mutate (expr-> as <ForLoop>( ));
285
285
case ExprType::IfThenElse:
286
- return ptr (mutator)->mutate (static_cast <IfThenElse*>(expr ));
286
+ return ptr (mutator)->mutate (expr-> as <IfThenElse>( ));
287
287
case ExprType::Allocate:
288
- return ptr (mutator)->mutate (static_cast <Allocate*>(expr ));
288
+ return ptr (mutator)->mutate (expr-> as <Allocate>( ));
289
289
default :
290
290
TORCH_INTERNAL_ASSERT (false , " Unknown exprtype in dispatch!" );
291
291
}
@@ -294,10 +294,10 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
294
294
template <typename T>
295
295
Statement* Statement::mutatorDispatch (T mutator, Statement* stmt) {
296
296
if (stmt->isVal ()) {
297
- return ptr (mutator)->mutate (static_cast <Val*>(stmt ));
297
+ return ptr (mutator)->mutate (stmt-> as <Val>( ));
298
298
}
299
299
if (stmt->isExpr ()) {
300
- return ptr (mutator)->mutate (static_cast <Expr*>(stmt ));
300
+ return ptr (mutator)->mutate (stmt-> as <Expr>( ));
301
301
}
302
302
TORCH_INTERNAL_ASSERT (false , " Unknown stmttype in dispatch!" );
303
303
}
@@ -352,39 +352,47 @@ template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
352
352
void OptOutDispatch::handle (Statement* s) {
353
353
Statement::dispatch (this , s);
354
354
}
355
+
355
356
void OptOutDispatch::handle (Expr* e) {
356
357
Expr::dispatch (this , e);
357
358
}
359
+
358
360
void OptOutDispatch::handle (Val* v) {
359
361
Val::dispatch (this , v);
360
362
}
361
363
362
364
void OptInDispatch::handle (Statement* s) {
363
365
Statement::dispatch (this , s);
364
366
}
367
+
365
368
void OptInDispatch::handle (Expr* e) {
366
369
Expr::dispatch (this , e);
367
370
}
371
+
368
372
void OptInDispatch::handle (Val* v) {
369
373
Val::dispatch (this , v);
370
374
}
371
375
372
376
void OptOutConstDispatch::handle (const Statement* s) {
373
377
Statement::constDispatch (this , s);
374
378
}
379
+
375
380
void OptOutConstDispatch::handle (const Expr* e) {
376
381
Expr::constDispatch (this , e);
377
382
}
383
+
378
384
void OptOutConstDispatch::handle (const Val* v) {
379
385
Val::constDispatch (this , v);
380
386
}
381
387
382
388
void OptInConstDispatch::handle (const Statement* s) {
383
389
Statement::constDispatch (this , s);
384
390
}
391
+
385
392
void OptInConstDispatch::handle (const Expr* e) {
386
393
Expr::constDispatch (this , e);
387
394
}
395
+
388
396
void OptInConstDispatch::handle (const Val* v) {
389
397
Val::constDispatch (this , v);
390
398
}
@@ -407,9 +415,11 @@ Statement* OptInMutator::mutate(Val* v) {
407
415
Statement* OptOutMutator::mutate (Statement* s) {
408
416
return Statement::mutatorDispatch (this , s);
409
417
}
418
+
410
419
Statement* OptOutMutator::mutate (Expr* e) {
411
420
return Expr::mutatorDispatch (this , e);
412
421
}
422
+
413
423
Statement* OptOutMutator::mutate (Val* v) {
414
424
// If value is already mutated, return the mutation
415
425
if (mutations.find (v) != mutations.end ())
0 commit comments