Skip to content

Commit 3be3883

Browse files
authored
[mlir][VectorOps] Support string literals in vector.print (#68695)
Printing strings within integration tests is currently quite annoyingly verbose, and can't be tucked into shared helpers as the types depend on the length of the string: ``` llvm.mlir.global internal constant @hello_world("Hello, World!\0") func.func @entry() { %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>> %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8> llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> () return } ``` So this patch adds a simple extension to `vector.print` to simplify this: ``` func.func @entry() { // Print a vector of characters ;) vector.print str "Hello, World!" return } ``` Most of the logic for this is now shared with `cf.assert` which already does something similar. Depends on #68694
1 parent 1072fcd commit 3be3883

File tree

14 files changed

+204
-58
lines changed

14 files changed

+204
-58
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
10+
#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
11+
12+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "llvm/ADT/StringRef.h"
14+
#include <optional>
15+
16+
namespace mlir {
17+
18+
class OpBuilder;
19+
class LLVMTypeConverter;
20+
21+
namespace LLVM {
22+
23+
/// Generate IR that prints the given string to stdout.
24+
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
25+
/// have the signature void(char const*). The default function is `printString`.
26+
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
27+
StringRef symbolName, StringRef string,
28+
const LLVMTypeConverter &typeConverter,
29+
bool addNewline = true,
30+
std::optional<StringRef> runtimeFunctionName = {});
31+
} // namespace LLVM
32+
33+
} // namespace mlir
34+
35+
#endif

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Support/LLVM.h"
19+
#include <optional>
1920

2021
namespace mlir {
2122
class Location;
@@ -38,8 +39,12 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
3839
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
3940
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
4041
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,
42-
bool opaquePointers);
42+
/// Declares a function to print a C-string.
43+
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
44+
/// have the signature void(char const*). The default function is `printString`.
45+
LLVM::LLVMFuncOp
46+
lookupOrCreatePrintStringFn(ModuleOp moduleOp, bool opaquePointers,
47+
std::optional<StringRef> runtimeFunctionName = {});
4348
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
4449
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
4550
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

+33-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2626
include "mlir/Interfaces/SideEffectInterfaces.td"
2727
include "mlir/Interfaces/VectorInterfaces.td"
2828
include "mlir/Interfaces/ViewLikeInterface.td"
29+
include "mlir/IR/BuiltinAttributes.td"
2930

3031
// TODO: Add an attribute to specify a different algebra with operators other
3132
// than the current set: {*, +}.
@@ -2477,12 +2478,18 @@ def Vector_TransposeOp :
24772478
}
24782479

24792480
def Vector_PrintOp :
2480-
Vector_Op<"print", []>,
2481+
Vector_Op<"print", [
2482+
PredOpTrait<
2483+
"`source` or `punctuation` are not set when printing strings",
2484+
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
2485+
>,
2486+
]>,
24812487
Arguments<(ins Optional<Type<Or<[
24822488
AnyVectorOfAnyRank.predicate,
24832489
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
24842490
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
2485-
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
2491+
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
2492+
OptionalAttr<Builtin_StringAttr>:$stringLiteral)
24862493
> {
24872494
let summary = "print operation (for testing and debugging)";
24882495
let description = [{
@@ -2521,6 +2528,13 @@ def Vector_PrintOp :
25212528
```mlir
25222529
vector.print punctuation <newline>
25232530
```
2531+
2532+
Additionally, to aid with debugging and testing `vector.print` can also
2533+
print constant strings:
2534+
2535+
```mlir
2536+
vector.print str "Hello, World!"
2537+
```
25242538
}];
25252539
let extraClassDeclaration = [{
25262540
Type getPrintType() {
@@ -2529,11 +2543,26 @@ def Vector_PrintOp :
25292543
}];
25302544
let builders = [
25312545
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
2532-
build($_builder, $_state, {}, punctuation);
2546+
build($_builder, $_state, {}, punctuation, {});
2547+
}]>,
2548+
OpBuilder<(ins "::mlir::Value":$source), [{
2549+
build($_builder, $_state, source, PrintPunctuation::NewLine);
2550+
}]>,
2551+
OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
2552+
build($_builder, $_state, source, punctuation, {});
2553+
}]>,
2554+
OpBuilder<(ins "::llvm::StringRef":$string), [{
2555+
build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
25332556
}]>,
25342557
];
25352558

2536-
let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
2559+
let assemblyFormat = [{
2560+
($source^ `:` type($source))?
2561+
oilist(
2562+
`str` $stringLiteral
2563+
| `punctuation` $punctuation)
2564+
attr-dict
2565+
}];
25372566
}
25382567

25392568
//===----------------------------------------------------------------------===//

mlir/include/mlir/ExecutionEngine/CRunnerUtils.h

+1
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
465465
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
466466
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
467467
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
468+
extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
468469
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
469470
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
470471
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

+4-46
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1717
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1818
#include "mlir/Conversion/LLVMCommon/Pattern.h"
19+
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
1920
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
2021
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2122
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
3637

3738
#define PASS_NAME "convert-cf-to-llvm"
3839

39-
static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
40-
std::string prefix = "assert_msg_";
41-
int counter = 0;
42-
while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
43-
++counter;
44-
return prefix + std::to_string(counter);
45-
}
46-
47-
/// Generate IR that prints the given string to stderr.
48-
static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
49-
StringRef msg,
50-
const LLVMTypeConverter &typeConverter) {
51-
auto ip = builder.saveInsertionPoint();
52-
builder.setInsertionPointToStart(moduleOp.getBody());
53-
MLIRContext *ctx = builder.getContext();
54-
55-
// Create a zero-terminated byte representation and allocate global symbol.
56-
SmallVector<uint8_t> elementVals;
57-
elementVals.append(msg.begin(), msg.end());
58-
elementVals.push_back(0);
59-
auto dataAttrType = RankedTensorType::get(
60-
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
61-
auto dataAttr =
62-
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
63-
auto arrayTy =
64-
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
65-
std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
66-
auto globalOp = builder.create<LLVM::GlobalOp>(
67-
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
68-
dataAttr);
69-
70-
// Emit call to `printStr` in runtime library.
71-
builder.restoreInsertionPoint(ip);
72-
auto msgAddr = builder.create<LLVM::AddressOfOp>(
73-
loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
74-
SmallVector<LLVM::GEPArg> indices(1, 0);
75-
Value gep = builder.create<LLVM::GEPOp>(
76-
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
77-
indices);
78-
Operation *printer = LLVM::lookupOrCreatePrintStrFn(
79-
moduleOp, typeConverter.useOpaquePointers());
80-
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
81-
gep);
82-
}
83-
8440
namespace {
8541
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
8642
/// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
10561

10662
// Failed block: Generate IR to print the message and call `abort`.
10763
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
108-
createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
64+
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65+
*getTypeConverter(), /*addNewLine=*/false,
66+
/*runtimeFunctionName=*/"puts");
10967
if (abortOnFailedAssert) {
11068
// Insert the `abort` declaration if necessary.
11169
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

mlir/lib/Conversion/LLVMCommon/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
33
LoweringOptions.cpp
44
MemRefBuilder.cpp
55
Pattern.cpp
6+
PrintCallHelper.cpp
67
StructBuilder.cpp
78
TypeConverter.cpp
89
VectorPattern.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
10+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
11+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/IR/Builders.h"
14+
#include "mlir/IR/BuiltinOps.h"
15+
#include "llvm/ADT/ArrayRef.h"
16+
17+
using namespace mlir;
18+
using namespace llvm;
19+
20+
static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
21+
StringRef symbolName) {
22+
static int counter = 0;
23+
std::string uniqueName = std::string(symbolName);
24+
while (moduleOp.lookupSymbol(uniqueName)) {
25+
uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
26+
}
27+
return uniqueName;
28+
}
29+
30+
void mlir::LLVM::createPrintStrCall(
31+
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
32+
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
33+
std::optional<StringRef> runtimeFunctionName) {
34+
auto ip = builder.saveInsertionPoint();
35+
builder.setInsertionPointToStart(moduleOp.getBody());
36+
MLIRContext *ctx = builder.getContext();
37+
38+
// Create a zero-terminated byte representation and allocate global symbol.
39+
SmallVector<uint8_t> elementVals;
40+
elementVals.append(string.begin(), string.end());
41+
if (addNewline)
42+
elementVals.push_back('\n');
43+
elementVals.push_back('\0');
44+
auto dataAttrType = RankedTensorType::get(
45+
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
46+
auto dataAttr =
47+
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
48+
auto arrayTy =
49+
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
50+
auto globalOp = builder.create<LLVM::GlobalOp>(
51+
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
52+
ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
53+
54+
// Emit call to `printStr` in runtime library.
55+
builder.restoreInsertionPoint(ip);
56+
auto msgAddr = builder.create<LLVM::AddressOfOp>(
57+
loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
58+
SmallVector<LLVM::GEPArg> indices(1, 0);
59+
Value gep = builder.create<LLVM::GEPOp>(
60+
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
61+
indices);
62+
Operation *printer = LLVM::lookupOrCreatePrintStringFn(
63+
moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName);
64+
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
65+
gep);
66+
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
1010

1111
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12+
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
1213
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1314
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1415
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15481549
}
15491550

15501551
auto punct = printOp.getPunctuation();
1551-
if (punct != PrintPunctuation::NoPunctuation) {
1552+
if (auto stringLiteral = printOp.getStringLiteral()) {
1553+
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1554+
*stringLiteral, *getTypeConverter());
1555+
} else if (punct != PrintPunctuation::NoPunctuation) {
15521556
emitCall(rewriter, printOp->getLoc(), [&] {
15531557
switch (punct) {
15541558
case PrintPunctuation::Close:

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
3030
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
3131
static constexpr llvm::StringRef kPrintF32 = "printF32";
3232
static constexpr llvm::StringRef kPrintF64 = "printF64";
33-
static constexpr llvm::StringRef kPrintStr = "puts";
33+
static constexpr llvm::StringRef kPrintString = "printString";
3434
static constexpr llvm::StringRef kPrintOpen = "printOpen";
3535
static constexpr llvm::StringRef kPrintClose = "printClose";
3636
static constexpr llvm::StringRef kPrintComma = "printComma";
@@ -107,9 +107,10 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
107107
return getCharPtr(context, opaquePointers);
108108
}
109109

110-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp,
111-
bool opaquePointers) {
112-
return lookupOrCreateFn(moduleOp, kPrintStr,
110+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
111+
ModuleOp moduleOp, bool opaquePointers,
112+
std::optional<StringRef> runtimeFunctionName) {
113+
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
113114
getCharPtr(moduleOp->getContext(), opaquePointers),
114115
LLVM::LLVMVoidType::get(moduleOp->getContext()));
115116
}

mlir/lib/ExecutionEngine/CRunnerUtils.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
5252
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
5353
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
5454
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
55+
extern "C" void printString(char const *s) { fputs(s, stdout); }
5556
extern "C" void printOpen() { fputs("( ", stdout); }
5657
extern "C" void printClose() { fputs(" )", stdout); }
5758
extern "C" void printComma() { fputs(", ", stdout); }

mlir/lib/ExecutionEngine/RunnerUtils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) {
158158
_mlir_ciface_printMemrefC64(&descriptor);
159159
}
160160

161-
extern "C" void printCString(char *str) { printf("%s", str); }
161+
/// Deprecated. This should be unified with printString from CRunnerUtils.
162+
extern "C" void printCString(char *str) { fputs(str, stdout); }
162163

163164
extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
164165
impl::printMemRef(*M);

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+14
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
10681068

10691069
// -----
10701070

1071+
// CHECK-LABEL: module {
1072+
// CHECK: llvm.func @printString(!llvm.ptr)
1073+
// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
1074+
// CHECK: @vector_print_string
1075+
// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
1076+
// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
1077+
// CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
1078+
func.func @vector_print_string() {
1079+
vector.print str "Hello, World!"
1080+
return
1081+
}
1082+
1083+
// -----
1084+
10711085
func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
10721086
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
10731087
return %0 : vector<2xf32>

mlir/test/Dialect/Vector/invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
10161016

10171017
// -----
10181018

1019+
func.func @cannot_print_string_with_punctuation_set() {
1020+
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
1021+
vector.print str "Whoops!" punctuation <comma>
1022+
return
1023+
}
1024+
1025+
// -----
1026+
1027+
func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
1028+
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
1029+
vector.print %vec: vector<[4]xf32> str "Yay!"
1030+
return
1031+
}
1032+
1033+
// -----
1034+
10191035
func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
10201036
%c2 = arith.constant 2 : index
10211037
%c3 = arith.constant 3 : index

0 commit comments

Comments
 (0)