@@ -39,58 +39,58 @@ using namespace mlir::tosa;
39
39
namespace {
40
40
41
41
static LogicalResult checkConstantOperandPad (Operation *op) {
42
- if (auto pad_op = dyn_cast<tosa::PadOp>(op)) {
42
+ if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
43
43
DenseElementsAttr paddings;
44
- if (!matchPattern (pad_op .getPadding (), m_Constant (&paddings)))
44
+ if (!matchPattern (padOp .getPadding (), m_Constant (&paddings)))
45
45
return op->emitOpError (" padding of pad is not constant" );
46
46
47
- DenseElementsAttr pad_const ;
48
- // Assume this op is zero-padding if pad_const is not presented.
49
- if (pad_op .getPadConst () &&
50
- !matchPattern (pad_op .getPadConst (), m_Constant (&pad_const )))
47
+ DenseElementsAttr padConst ;
48
+ // Assume this op is zero-padding if padConst is not presented.
49
+ if (padOp .getPadConst () &&
50
+ !matchPattern (padOp .getPadConst (), m_Constant (&padConst )))
51
51
return op->emitOpError (" pad_const of pad is not constant" );
52
52
}
53
53
return success ();
54
54
}
55
55
56
56
static LogicalResult checkConstantOperandTranspose (Operation *op) {
57
- if (auto transpose_op = dyn_cast<tosa::TransposeOp>(op)) {
57
+ if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
58
58
DenseElementsAttr perms;
59
- if (!matchPattern (transpose_op .getPerms (), m_Constant (&perms)))
59
+ if (!matchPattern (transposeOp .getPerms (), m_Constant (&perms)))
60
60
return op->emitOpError (" perms of transpose is not constant" );
61
61
}
62
62
return success ();
63
63
}
64
64
65
65
static LogicalResult checkConstantOperandFullyConnected (Operation *op) {
66
- if (auto fc_op = dyn_cast<tosa::FullyConnectedOp>(op)) {
66
+ if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
67
67
DenseElementsAttr weight;
68
- if (!matchPattern (fc_op .getWeight (), m_Constant (&weight)))
68
+ if (!matchPattern (fcOp .getWeight (), m_Constant (&weight)))
69
69
return op->emitOpError (" weight of fully_connected is not constant" );
70
70
71
71
DenseElementsAttr bias;
72
- if (!matchPattern (fc_op .getBias (), m_Constant (&bias)))
72
+ if (!matchPattern (fcOp .getBias (), m_Constant (&bias)))
73
73
return op->emitOpError (" bias of fully_connected is not constant" );
74
74
}
75
75
return success ();
76
76
}
77
77
78
- struct tosa_level_t {
78
+ struct TosaLevel {
79
79
int32_t MAX_RANK = 0 ;
80
80
int32_t MAX_KERNEL = 0 ;
81
81
int32_t MAX_STRIDE = 0 ;
82
82
int32_t MAX_SCALE = 0 ;
83
83
84
84
// @todo: MAX_LOG2_SIZE value and checks
85
85
86
- bool operator ==(const tosa_level_t &rhs) {
86
+ bool operator ==(const TosaLevel &rhs) {
87
87
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
88
88
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE ;
89
89
}
90
90
};
91
91
92
- static constexpr tosa_level_t TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
93
- static constexpr tosa_level_t TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
92
+ static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6 , 8192 , 8192 , 256 };
93
+ static constexpr TosaLevel TOSA_LEVEL_NONE = {0 , 0 , 0 , 0 };
94
94
95
95
// ===----------------------------------------------------------------------===//
96
96
// TOSA Validation Pass.
@@ -108,7 +108,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
108
108
void runOnOperation () final ;
109
109
110
110
LogicalResult applyConstantOperandCheck (Operation *op) {
111
- for (auto &checker : const_checkers ) {
111
+ for (auto &checker : constCheckers ) {
112
112
if (failed (checker (op)))
113
113
return failure ();
114
114
}
@@ -122,43 +122,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
122
122
123
123
private:
124
124
void populateConstantOperandChecks () {
125
- const_checkers .emplace_back (checkConstantOperandPad);
126
- const_checkers .emplace_back (checkConstantOperandTranspose);
127
- const_checkers .emplace_back (checkConstantOperandFullyConnected);
125
+ constCheckers .emplace_back (checkConstantOperandPad);
126
+ constCheckers .emplace_back (checkConstantOperandTranspose);
127
+ constCheckers .emplace_back (checkConstantOperandFullyConnected);
128
128
}
129
129
130
130
bool levelCheckKernel (Operation *op, int32_t v,
131
- const std::string &check_desc ) {
132
- if (v > tosa_level .MAX_KERNEL ) {
133
- op->emitOpError () << " failed level check: " << check_desc ;
131
+ const std::string &checkDesc ) {
132
+ if (v > tosaLevel .MAX_KERNEL ) {
133
+ op->emitOpError () << " failed level check: " << checkDesc ;
134
134
return false ;
135
135
}
136
136
return true ;
137
137
}
138
138
139
139
bool levelCheckStride (Operation *op, int32_t v,
140
- const std::string &check_desc ) {
141
- if (v > tosa_level .MAX_STRIDE ) {
142
- op->emitOpError () << " failed level check: " << check_desc ;
140
+ const std::string &checkDesc ) {
141
+ if (v > tosaLevel .MAX_STRIDE ) {
142
+ op->emitOpError () << " failed level check: " << checkDesc ;
143
143
return false ;
144
144
}
145
145
return true ;
146
146
}
147
147
148
- bool levelCheckScale (Operation *op, int32_t v,
149
- const std::string &check_desc) {
150
- if (v > tosa_level.MAX_SCALE ) {
151
- op->emitOpError () << " failed level check: " << check_desc;
148
+ bool levelCheckScale (Operation *op, int32_t v, const std::string &checkDesc) {
149
+ if (v > tosaLevel.MAX_SCALE ) {
150
+ op->emitOpError () << " failed level check: " << checkDesc;
152
151
return false ;
153
152
}
154
153
return true ;
155
154
}
156
155
157
156
bool levelCheckRank (Operation *op, const Value &v,
158
- const std::string &check_desc ) {
157
+ const std::string &checkDesc ) {
159
158
if (ShapedType type = dyn_cast<ShapedType>(v.getType ())) {
160
- if (type.getRank () > tosa_level .MAX_RANK ) {
161
- op->emitOpError () << " failed level check: " << check_desc ;
159
+ if (type.getRank () > tosaLevel .MAX_RANK ) {
160
+ op->emitOpError () << " failed level check: " << checkDesc ;
162
161
return false ;
163
162
}
164
163
}
@@ -182,8 +181,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
182
181
}
183
182
184
183
bool levelCheckRanks (Operation *op) {
185
- #define CHECK_RANKS_FOR (tosa_op ) \
186
- if (!levelCheckRanksFor<tosa_op ##Op>(op)) \
184
+ #define CHECK_RANKS_FOR (tosaOp ) \
185
+ if (!levelCheckRanksFor<tosaOp ##Op>(op)) \
187
186
return false ;
188
187
189
188
// tensor operators:
@@ -257,18 +256,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
257
256
// Pool Op: level check kernel/stride/pad values
258
257
template <typename T>
259
258
bool levelCheckPool (Operation *op) {
260
- if (auto pool_op = dyn_cast<T>(op)) {
261
- for (auto k : pool_op .getKernel ()) {
259
+ if (auto poolOp = dyn_cast<T>(op)) {
260
+ for (auto k : poolOp .getKernel ()) {
262
261
if (!levelCheckKernel (op, k, " kernel <= MAX_KERNEL" )) {
263
262
return false ;
264
263
}
265
264
}
266
- for (auto s : pool_op .getStride ()) {
265
+ for (auto s : poolOp .getStride ()) {
267
266
if (!levelCheckStride (op, s, " stride <= MAX_STRIDE" )) {
268
267
return false ;
269
268
}
270
269
}
271
- for (auto p : pool_op .getPad ()) {
270
+ for (auto p : poolOp .getPad ()) {
272
271
if (!levelCheckKernel (op, p, " pad <= MAX_KERNEL" )) {
273
272
return false ;
274
273
}
@@ -280,27 +279,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
280
279
// Conv Op: level check dilation/stride/pad values
281
280
template <typename T>
282
281
bool levelCheckConv (Operation *op) {
283
- if (auto conv_op = dyn_cast<T>(op)) {
282
+ if (auto convOp = dyn_cast<T>(op)) {
284
283
285
- for (auto k : conv_op .getDilation ()) {
284
+ for (auto k : convOp .getDilation ()) {
286
285
if (!levelCheckKernel (op, k, " dilation <= MAX_KERNEL" )) {
287
286
return false ;
288
287
}
289
288
}
290
- for (auto p : conv_op .getPad ()) {
289
+ for (auto p : convOp .getPad ()) {
291
290
if (!levelCheckKernel (op, p, " pad <= MAX_KERNEL" )) {
292
291
return false ;
293
292
}
294
293
}
295
- for (auto s : conv_op .getStride ()) {
294
+ for (auto s : convOp .getStride ()) {
296
295
if (!levelCheckStride (op, s, " stride <= MAX_STRIDE" )) {
297
296
return false ;
298
297
}
299
298
}
300
- auto dilation = conv_op .getDilation ();
301
- if (ShapedType weight_type =
299
+ auto dilation = convOp .getDilation ();
300
+ if (ShapedType weightType =
302
301
dyn_cast<ShapedType>(op->getOperand (1 ).getType ())) {
303
- auto shape = weight_type .getShape ();
302
+ auto shape = weightType .getShape ();
304
303
if (isa<tosa::Conv2DOp>(op)) {
305
304
assert (shape.size () == 4 );
306
305
assert (dilation.size () == 2 );
@@ -354,9 +353,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
354
353
// TransposeConv2d op: level check kH/kW, outpad, and stride
355
354
bool levelCheckTransposeConv2d (Operation *op) {
356
355
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
357
- if (ShapedType filter_type =
356
+ if (ShapedType filterType =
358
357
transpose.getFilter ().getType ().dyn_cast <ShapedType>()) {
359
- auto shape = filter_type .getShape ();
358
+ auto shape = filterType .getShape ();
360
359
assert (shape.size () == 4 );
361
360
// level check kernel sizes for kH and KW
362
361
if (!levelCheckKernel (op, shape[1 ], " KH <= MAX_KERNEL" ) ||
@@ -382,13 +381,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
382
381
bool levelCheckResize (Operation *op) {
383
382
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
384
383
auto scale = resize.getScale ();
385
- int16_t scale_y_n = scale[0 ];
386
- int16_t scale_y_d = scale[1 ];
387
- int16_t scale_x_n = scale[2 ];
388
- int16_t scale_x_d = scale[3 ];
389
- if (!levelCheckScale (op, scale_y_n / scale_y_d ,
384
+ int16_t scaleYN = scale[0 ];
385
+ int16_t scaleYD = scale[1 ];
386
+ int16_t scaleXN = scale[2 ];
387
+ int16_t scaleXD = scale[3 ];
388
+ if (!levelCheckScale (op, scaleYN / scaleYD ,
390
389
" scale_y_n/scale_y_d <= MAX_SCALE" ) ||
391
- !levelCheckScale (op, scale_x_n / scale_x_d ,
390
+ !levelCheckScale (op, scaleXN / scaleXD ,
392
391
" scale_x_n/scale_x_d <= MAX_SCALE" )) {
393
392
return false ;
394
393
}
@@ -399,22 +398,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
399
398
// configure profile and level values from pass options profileName and
400
399
// levelName
401
400
void configLevelAndProfile () {
402
- tosa_level = TOSA_LEVEL_NONE;
401
+ tosaLevel = TOSA_LEVEL_NONE;
403
402
if (level == TosaLevelEnum::EightK) {
404
- tosa_level = TOSA_LEVEL_EIGHTK;
403
+ tosaLevel = TOSA_LEVEL_EIGHTK;
405
404
}
406
405
}
407
406
408
407
bool CheckVariable (Operation *op);
409
408
bool CheckVariableReadOrWrite (Operation *op);
410
409
411
- SmallVector<std::function<LogicalResult(Operation *)>> const_checkers ;
412
- tosa_level_t tosa_level ;
413
- DenseMap<StringAttr, mlir::Type> variables_map ;
410
+ SmallVector<std::function<LogicalResult(Operation *)>> constCheckers ;
411
+ TosaLevel tosaLevel ;
412
+ DenseMap<StringAttr, mlir::Type> variablesMap ;
414
413
};
415
414
416
415
LogicalResult TosaValidation::applyLevelCheck (Operation *op) {
417
- if (tosa_level == TOSA_LEVEL_NONE) {
416
+ if (tosaLevel == TOSA_LEVEL_NONE) {
418
417
// no need to do level checks
419
418
return success ();
420
419
}
@@ -439,24 +438,24 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439
438
}
440
439
441
440
inline bool CompatibleTypes (const mlir::Type &type,
442
- const mlir::Type &declared_type ) {
441
+ const mlir::Type &declaredType ) {
443
442
// for now, simply use type equality comparison
444
- return type == declared_type ;
443
+ return type == declaredType ;
445
444
}
446
445
447
446
bool TosaValidation::CheckVariable (Operation *op) {
448
447
if (isa<mlir::tosa::VariableOp>(op)) {
449
- auto name_attr = cast<mlir::StringAttr>(op->getAttr (" name" ));
448
+ auto nameAttr = cast<mlir::StringAttr>(op->getAttr (" name" ));
450
449
451
- if (variables_map .count (name_attr )) {
450
+ if (variablesMap .count (nameAttr )) {
452
451
op->emitOpError () << " name has already been declared" ;
453
452
return false ;
454
453
}
455
454
456
- auto type_attr = cast<mlir::TypeAttr>(op->getAttr (" type" ));
457
- mlir::Type type = type_attr .getValue ();
455
+ auto typeAttr = cast<mlir::TypeAttr>(op->getAttr (" type" ));
456
+ mlir::Type type = typeAttr .getValue ();
458
457
459
- variables_map[name_attr ] = type;
458
+ variablesMap[nameAttr ] = type;
460
459
}
461
460
462
461
return true ;
@@ -465,26 +464,26 @@ bool TosaValidation::CheckVariable(Operation *op) {
465
464
bool TosaValidation::CheckVariableReadOrWrite (Operation *op) {
466
465
if (isa<mlir::tosa::VariableReadOp>(op) ||
467
466
isa<mlir::tosa::VariableWriteOp>(op)) {
468
- auto name_attr = cast<mlir::StringAttr>(op->getAttr (" name" ));
467
+ auto nameAttr = cast<mlir::StringAttr>(op->getAttr (" name" ));
469
468
470
- if (!variables_map .count (name_attr )) {
469
+ if (!variablesMap .count (nameAttr )) {
471
470
op->emitOpError () << " name has not been declared" ;
472
471
return false ;
473
472
}
474
473
475
- auto var_type = variables_map[name_attr ];
474
+ auto varType = variablesMap[nameAttr ];
476
475
477
476
for (auto v : op->getOperands ()) {
478
477
auto type = v.getType ();
479
- if (!CompatibleTypes (type, var_type )) {
478
+ if (!CompatibleTypes (type, varType )) {
480
479
op->emitOpError () << " operand type does not equal variable type" ;
481
480
return false ;
482
481
}
483
482
}
484
483
485
484
for (auto v : op->getResults ()) {
486
485
auto type = v.getType ();
487
- if (!CompatibleTypes (type, var_type )) {
486
+ if (!CompatibleTypes (type, varType )) {
488
487
op->emitOpError () << " result type does not equal variable type" ;
489
488
return false ;
490
489
}
0 commit comments