diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 25209ce449745..9d46c5b3bdc5e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -303,8 +303,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, indices.push_back( builder.getInt32(valueOrAttr.get().getInt())); } - Type baseElementType = op.getSourceElementType(); - llvm::Type *elementType = moduleTranslation.convertType(baseElementType); + + Type elemTypeFromAttr = op.getSourceElementType(); + auto ptrType = ::llvm::cast(op.getType()); + Type elemTypeFromPtrType = ptrType.getElementType(); + + llvm::Type *elementType = moduleTranslation.convertType( + elemTypeFromAttr ? elemTypeFromAttr : elemTypeFromPtrType); $res = builder.CreateGEP(elementType, $base, indices, "", $inbounds); }]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 95c04098d05fc..62cb595069e66 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -287,14 +287,17 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { } /// Checks that the elemental type is present in either the pointer type or -/// the attribute, but not both. +/// the attribute, but not in none or both. static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, std::optional ptrElementType) { - if (ptrType.isOpaque() && !ptrElementType.has_value()) { + bool typePresentInPointerType = !ptrType.isOpaque(); + bool typePresentInAttribute = ptrElementType.has_value(); + + if (!typePresentInPointerType && !typePresentInAttribute) { return op->emitOpError() << "expected '" << kElemTypeAttrName << "' attribute if opaque pointer type is used"; } - if (!ptrType.isOpaque() && ptrElementType.has_value()) { + if (typePresentInPointerType && typePresentInAttribute) { return op->emitOpError() << "unexpected '" << kElemTypeAttrName << "' attribute when non-opaque pointer type is used"; diff --git a/mlir/test/Target/LLVMIR/opaque-ptr.mlir b/mlir/test/Target/LLVMIR/opaque-ptr.mlir index c21f9b0542deb..3bde192b4cc4d 100644 --- a/mlir/test/Target/LLVMIR/opaque-ptr.mlir +++ b/mlir/test/Target/LLVMIR/opaque-ptr.mlir @@ -42,6 +42,15 @@ llvm.func @opaque_ptr_gep_struct(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr { llvm.return %0 : !llvm.ptr } +// CHECK-LABEL: @opaque_ptr_elem_type +llvm.func @opaque_ptr_elem_type(%0: !llvm.ptr) -> !llvm.ptr { + // CHECK: getelementptr ptr, ptr + %1 = llvm.getelementptr %0[0] { elem_type = !llvm.ptr } : (!llvm.ptr) -> !llvm.ptr + // CHECK: getelementptr ptr, ptr + %2 = llvm.getelementptr %0[0] : (!llvm.ptr) -> !llvm.ptr + llvm.return %1 : !llvm.ptr +} + // CHECK-LABEL: @opaque_ptr_matrix_load_store llvm.func @opaque_ptr_matrix_load_store(%ptr: !llvm.ptr, %stride: i64) -> vector<48 x f32> { // CHECK: call <48 x float> @llvm.matrix.column.major.load.v48f32.i64