diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td index d761743a82bf8..39d24595ec1c4 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td @@ -58,8 +58,8 @@ def MeshSharding : AttrDef { let parameters = (ins AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster, - ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes, - OptionalArrayRefParameter<"int8_t">:$partial_axes, + ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes, + OptionalArrayRefParameter<"int32_t">:$partial_axes, OptionalParameter<"::mlir::mesh::Partial">:$partial_type ); diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 8ca4b66531042..a8aa0a694bee2 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> { }]; let arguments = (ins SymbolNameAttr:$sym_name, - I8Attr:$rank, + I64Attr:$rank, DefaultValuedAttr:$dim_sizes ); let assemblyFormat = [{ diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 379392ace4696..f1fabf95a68b7 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -350,8 +350,7 @@ template ::value && !std::is_convertible::value && !std::is_convertible::value && - !llvm::is_one_of::value, + !llvm::is_one_of::value, T> * = nullptr> inline std::enable_if_t::value, AsmPrinterT &> @@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) { return p << (value ? StringRef("true") : "false"); } -/// Specialization for 8-bit integers to ensure values are printed as integers -// and not characters. -template < - typename AsmPrinterT, typename T, - std::enable_if_t::value, T> * = nullptr> -inline std::enable_if_t::value, - AsmPrinterT &> -operator<<(AsmPrinterT &p, T value) { - return p << static_cast(value); -} - template inline std::enable_if_t::value, AsmPrinterT &> diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index b2a4710252875..fc91fd994f12d 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, LogicalResult ClusterOp::verify() { ArrayRef dimSizes = getDimSizes(); - uint8_t rank = getRank(); + uint64_t rank = getRank(); if (rank == 0) return emitOpError("rank of cluster is expected to be a positive integer"); @@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() { LogicalResult MeshShardingAttr::verify(function_ref emitError, - SymbolRefAttr, ArrayRef splitAxes, - ArrayRef partialAxes, Partial) { + SymbolRefAttr, ArrayRef splitAxes, + ArrayRef partialAxes, Partial) { // TODO: At present cluster symbol ref is not verified. This is due to the // difficulty in fetching the corresponding symbol op based on an attribute. - llvm::SmallSet visitedAxes; + llvm::SmallSet visitedAxes; - auto checkMeshAxis = [&](ArrayRef axesArray) -> LogicalResult { - for (int8_t axis : axesArray) { + auto checkMeshAxis = [&](ArrayRef axesArray) -> LogicalResult { + for (int32_t axis : axesArray) { if (axis < 0) return emitError() << "mesh axis is expected to be non-negative"; if (!visitedAxes.insert(axis).second) @@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref emitError, return success(); }; - for (DenseI8ArrayAttr subAxes : splitAxes) { - ArrayRef subAxesArray = subAxes.asArrayRef(); + for (DenseI32ArrayAttr subAxes : splitAxes) { + ArrayRef subAxesArray = subAxes.asArrayRef(); if (failed(checkMeshAxis(subAxesArray))) return failure(); }