Skip to content

Commit 3745e70

Browse files
authored
[Tosa] Rename variables to coding style guideline (#69509)
This patch fixes variable names in the style guide. Specifically, names in the form xyz_abc are changed to the form xyzAbc Signed-off-by: Tai Ly <[email protected]>
1 parent 362b115 commit 3745e70

File tree

3 files changed

+81
-82
lines changed

3 files changed

+81
-82
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,24 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
5454
return rewriter.notifyMatchFailure(loc,
5555
"cannot rewrite as its already correct");
5656

57-
Value input1_copy = input1;
58-
Value input2_copy = input2;
59-
if (EqualizeRanks(rewriter, loc, input1_copy, input2_copy).failed()) {
57+
Value input1Copy = input1;
58+
Value input2Copy = input2;
59+
if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
6060
return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
6161
}
6262

6363
// Verify the rank agrees with the output type if the output type is ranked.
6464
if (outputType) {
6565
if (outputType.getRank() !=
66-
llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
66+
llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
6767
outputType.getRank() !=
68-
llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
68+
llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
6969
return rewriter.notifyMatchFailure(
7070
loc, "the reshaped type doesn't agrees with the ranked output type");
7171
}
7272

73-
input1 = input1_copy;
74-
input2 = input2_copy;
73+
input1 = input1Copy;
74+
input2 = input2Copy;
7575

7676
return success();
7777
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -39,58 +39,58 @@ using namespace mlir::tosa;
3939
namespace {
4040

4141
static LogicalResult checkConstantOperandPad(Operation *op) {
42-
if (auto pad_op = dyn_cast<tosa::PadOp>(op)) {
42+
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
4343
DenseElementsAttr paddings;
44-
if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings)))
44+
if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
4545
return op->emitOpError("padding of pad is not constant");
4646

47-
DenseElementsAttr pad_const;
48-
// Assume this op is zero-padding if pad_const is not presented.
49-
if (pad_op.getPadConst() &&
50-
!matchPattern(pad_op.getPadConst(), m_Constant(&pad_const)))
47+
DenseElementsAttr padConst;
48+
// Assume this op is zero-padding if padConst is not presented.
49+
if (padOp.getPadConst() &&
50+
!matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
5151
return op->emitOpError("pad_const of pad is not constant");
5252
}
5353
return success();
5454
}
5555

5656
static LogicalResult checkConstantOperandTranspose(Operation *op) {
57-
if (auto transpose_op = dyn_cast<tosa::TransposeOp>(op)) {
57+
if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
5858
DenseElementsAttr perms;
59-
if (!matchPattern(transpose_op.getPerms(), m_Constant(&perms)))
59+
if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
6060
return op->emitOpError("perms of transpose is not constant");
6161
}
6262
return success();
6363
}
6464

6565
static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
66-
if (auto fc_op = dyn_cast<tosa::FullyConnectedOp>(op)) {
66+
if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
6767
DenseElementsAttr weight;
68-
if (!matchPattern(fc_op.getWeight(), m_Constant(&weight)))
68+
if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
6969
return op->emitOpError("weight of fully_connected is not constant");
7070

7171
DenseElementsAttr bias;
72-
if (!matchPattern(fc_op.getBias(), m_Constant(&bias)))
72+
if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
7373
return op->emitOpError("bias of fully_connected is not constant");
7474
}
7575
return success();
7676
}
7777

78-
struct tosa_level_t {
78+
struct TosaLevel {
7979
int32_t MAX_RANK = 0;
8080
int32_t MAX_KERNEL = 0;
8181
int32_t MAX_STRIDE = 0;
8282
int32_t MAX_SCALE = 0;
8383

8484
// @todo: MAX_LOG2_SIZE value and checks
8585

86-
bool operator==(const tosa_level_t &rhs) {
86+
bool operator==(const TosaLevel &rhs) {
8787
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
8888
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
8989
}
9090
};
9191

92-
static constexpr tosa_level_t TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
93-
static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
92+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
93+
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
9494

9595
//===----------------------------------------------------------------------===//
9696
// TOSA Validation Pass.
@@ -108,7 +108,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
108108
void runOnOperation() final;
109109

110110
LogicalResult applyConstantOperandCheck(Operation *op) {
111-
for (auto &checker : const_checkers) {
111+
for (auto &checker : constCheckers) {
112112
if (failed(checker(op)))
113113
return failure();
114114
}
@@ -122,43 +122,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
122122

123123
private:
124124
void populateConstantOperandChecks() {
125-
const_checkers.emplace_back(checkConstantOperandPad);
126-
const_checkers.emplace_back(checkConstantOperandTranspose);
127-
const_checkers.emplace_back(checkConstantOperandFullyConnected);
125+
constCheckers.emplace_back(checkConstantOperandPad);
126+
constCheckers.emplace_back(checkConstantOperandTranspose);
127+
constCheckers.emplace_back(checkConstantOperandFullyConnected);
128128
}
129129

130130
bool levelCheckKernel(Operation *op, int32_t v,
131-
const std::string &check_desc) {
132-
if (v > tosa_level.MAX_KERNEL) {
133-
op->emitOpError() << "failed level check: " << check_desc;
131+
const std::string &checkDesc) {
132+
if (v > tosaLevel.MAX_KERNEL) {
133+
op->emitOpError() << "failed level check: " << checkDesc;
134134
return false;
135135
}
136136
return true;
137137
}
138138

139139
bool levelCheckStride(Operation *op, int32_t v,
140-
const std::string &check_desc) {
141-
if (v > tosa_level.MAX_STRIDE) {
142-
op->emitOpError() << "failed level check: " << check_desc;
140+
const std::string &checkDesc) {
141+
if (v > tosaLevel.MAX_STRIDE) {
142+
op->emitOpError() << "failed level check: " << checkDesc;
143143
return false;
144144
}
145145
return true;
146146
}
147147

148-
bool levelCheckScale(Operation *op, int32_t v,
149-
const std::string &check_desc) {
150-
if (v > tosa_level.MAX_SCALE) {
151-
op->emitOpError() << "failed level check: " << check_desc;
148+
bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
149+
if (v > tosaLevel.MAX_SCALE) {
150+
op->emitOpError() << "failed level check: " << checkDesc;
152151
return false;
153152
}
154153
return true;
155154
}
156155

157156
bool levelCheckRank(Operation *op, const Value &v,
158-
const std::string &check_desc) {
157+
const std::string &checkDesc) {
159158
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
160-
if (type.getRank() > tosa_level.MAX_RANK) {
161-
op->emitOpError() << "failed level check: " << check_desc;
159+
if (type.getRank() > tosaLevel.MAX_RANK) {
160+
op->emitOpError() << "failed level check: " << checkDesc;
162161
return false;
163162
}
164163
}
@@ -182,8 +181,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
182181
}
183182

184183
bool levelCheckRanks(Operation *op) {
185-
#define CHECK_RANKS_FOR(tosa_op) \
186-
if (!levelCheckRanksFor<tosa_op##Op>(op)) \
184+
#define CHECK_RANKS_FOR(tosaOp) \
185+
if (!levelCheckRanksFor<tosaOp##Op>(op)) \
187186
return false;
188187

189188
// tensor operators:
@@ -257,18 +256,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
257256
// Pool Op: level check kernel/stride/pad values
258257
template <typename T>
259258
bool levelCheckPool(Operation *op) {
260-
if (auto pool_op = dyn_cast<T>(op)) {
261-
for (auto k : pool_op.getKernel()) {
259+
if (auto poolOp = dyn_cast<T>(op)) {
260+
for (auto k : poolOp.getKernel()) {
262261
if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
263262
return false;
264263
}
265264
}
266-
for (auto s : pool_op.getStride()) {
265+
for (auto s : poolOp.getStride()) {
267266
if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
268267
return false;
269268
}
270269
}
271-
for (auto p : pool_op.getPad()) {
270+
for (auto p : poolOp.getPad()) {
272271
if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
273272
return false;
274273
}
@@ -280,27 +279,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
280279
// Conv Op: level check dilation/stride/pad values
281280
template <typename T>
282281
bool levelCheckConv(Operation *op) {
283-
if (auto conv_op = dyn_cast<T>(op)) {
282+
if (auto convOp = dyn_cast<T>(op)) {
284283

285-
for (auto k : conv_op.getDilation()) {
284+
for (auto k : convOp.getDilation()) {
286285
if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
287286
return false;
288287
}
289288
}
290-
for (auto p : conv_op.getPad()) {
289+
for (auto p : convOp.getPad()) {
291290
if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
292291
return false;
293292
}
294293
}
295-
for (auto s : conv_op.getStride()) {
294+
for (auto s : convOp.getStride()) {
296295
if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
297296
return false;
298297
}
299298
}
300-
auto dilation = conv_op.getDilation();
301-
if (ShapedType weight_type =
299+
auto dilation = convOp.getDilation();
300+
if (ShapedType weightType =
302301
dyn_cast<ShapedType>(op->getOperand(1).getType())) {
303-
auto shape = weight_type.getShape();
302+
auto shape = weightType.getShape();
304303
if (isa<tosa::Conv2DOp>(op)) {
305304
assert(shape.size() == 4);
306305
assert(dilation.size() == 2);
@@ -354,9 +353,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
354353
// TransposeConv2d op: level check kH/kW, outpad, and stride
355354
bool levelCheckTransposeConv2d(Operation *op) {
356355
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
357-
if (ShapedType filter_type =
356+
if (ShapedType filterType =
358357
transpose.getFilter().getType().dyn_cast<ShapedType>()) {
359-
auto shape = filter_type.getShape();
358+
auto shape = filterType.getShape();
360359
assert(shape.size() == 4);
361360
// level check kernel sizes for kH and KW
362361
if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
@@ -382,13 +381,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
382381
bool levelCheckResize(Operation *op) {
383382
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
384383
auto scale = resize.getScale();
385-
int16_t scale_y_n = scale[0];
386-
int16_t scale_y_d = scale[1];
387-
int16_t scale_x_n = scale[2];
388-
int16_t scale_x_d = scale[3];
389-
if (!levelCheckScale(op, scale_y_n / scale_y_d,
384+
int16_t scaleYN = scale[0];
385+
int16_t scaleYD = scale[1];
386+
int16_t scaleXN = scale[2];
387+
int16_t scaleXD = scale[3];
388+
if (!levelCheckScale(op, scaleYN / scaleYD,
390389
"scale_y_n/scale_y_d <= MAX_SCALE") ||
391-
!levelCheckScale(op, scale_x_n / scale_x_d,
390+
!levelCheckScale(op, scaleXN / scaleXD,
392391
"scale_x_n/scale_x_d <= MAX_SCALE")) {
393392
return false;
394393
}
@@ -399,22 +398,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
399398
// configure profile and level values from pass options profileName and
400399
// levelName
401400
void configLevelAndProfile() {
402-
tosa_level = TOSA_LEVEL_NONE;
401+
tosaLevel = TOSA_LEVEL_NONE;
403402
if (level == TosaLevelEnum::EightK) {
404-
tosa_level = TOSA_LEVEL_EIGHTK;
403+
tosaLevel = TOSA_LEVEL_EIGHTK;
405404
}
406405
}
407406

408407
bool CheckVariable(Operation *op);
409408
bool CheckVariableReadOrWrite(Operation *op);
410409

411-
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
412-
tosa_level_t tosa_level;
413-
DenseMap<StringAttr, mlir::Type> variables_map;
410+
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
411+
TosaLevel tosaLevel;
412+
DenseMap<StringAttr, mlir::Type> variablesMap;
414413
};
415414

416415
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
417-
if (tosa_level == TOSA_LEVEL_NONE) {
416+
if (tosaLevel == TOSA_LEVEL_NONE) {
418417
// no need to do level checks
419418
return success();
420419
}
@@ -439,24 +438,24 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
439438
}
440439

441440
inline bool CompatibleTypes(const mlir::Type &type,
442-
const mlir::Type &declared_type) {
441+
const mlir::Type &declaredType) {
443442
// for now, simply use type equality comparison
444-
return type == declared_type;
443+
return type == declaredType;
445444
}
446445

447446
bool TosaValidation::CheckVariable(Operation *op) {
448447
if (isa<mlir::tosa::VariableOp>(op)) {
449-
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
448+
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
450449

451-
if (variables_map.count(name_attr)) {
450+
if (variablesMap.count(nameAttr)) {
452451
op->emitOpError() << "name has already been declared";
453452
return false;
454453
}
455454

456-
auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
457-
mlir::Type type = type_attr.getValue();
455+
auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
456+
mlir::Type type = typeAttr.getValue();
458457

459-
variables_map[name_attr] = type;
458+
variablesMap[nameAttr] = type;
460459
}
461460

462461
return true;
@@ -465,26 +464,26 @@ bool TosaValidation::CheckVariable(Operation *op) {
465464
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
466465
if (isa<mlir::tosa::VariableReadOp>(op) ||
467466
isa<mlir::tosa::VariableWriteOp>(op)) {
468-
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
467+
auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
469468

470-
if (!variables_map.count(name_attr)) {
469+
if (!variablesMap.count(nameAttr)) {
471470
op->emitOpError() << "name has not been declared";
472471
return false;
473472
}
474473

475-
auto var_type = variables_map[name_attr];
474+
auto varType = variablesMap[nameAttr];
476475

477476
for (auto v : op->getOperands()) {
478477
auto type = v.getType();
479-
if (!CompatibleTypes(type, var_type)) {
478+
if (!CompatibleTypes(type, varType)) {
480479
op->emitOpError() << "operand type does not equal variable type";
481480
return false;
482481
}
483482
}
484483

485484
for (auto v : op->getResults()) {
486485
auto type = v.getType();
487-
if (!CompatibleTypes(type, var_type)) {
486+
if (!CompatibleTypes(type, varType)) {
488487
op->emitOpError() << "result type does not equal variable type";
489488
return false;
490489
}

mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
107107
}
108108
}
109109

110-
#define GET_UQTYPE(input_type) \
111-
(llvm::dyn_cast<quant::UniformQuantizedType>((input_type).getElementType()))
112-
#define GET_QTYPE(input_type) \
113-
(llvm::dyn_cast<quant::QuantizedType>((input_type).getElementType()))
110+
#define GET_UQTYPE(inputType) \
111+
(llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
112+
#define GET_QTYPE(inputType) \
113+
(llvm::dyn_cast<quant::QuantizedType>((inputType).getElementType()))
114114

115115
/// Method to build ConvOpQuantizationAttr, called from
116116
/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:

0 commit comments

Comments
 (0)