Skip to content

Commit 55cb7d3

Browse files
authored
Enable switching to bare pointer ABI for MLIR (#1333)
Once ROCm/rocMLIR#690 lands, the ABI for MLIR-generated kernels will change. This commit prepares MIGraphX for the change by conditionally selecting the new ABI if MLIR reports a sufficiently high API version in its headers.
1 parent 7ecb2de commit 55cb7d3

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/targets/gpu/mlir.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
#include <deque>
4949
#include <variant>
5050

51+
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
52+
#define MIGRAPHX_MLIR_BARE_POINTER
53+
#endif
54+
5155
namespace migraphx {
5256
inline namespace MIGRAPHX_INLINE_NS {
5357
namespace gpu {
@@ -606,9 +610,15 @@ instruction_ref insert_mlir(module& m,
606610
code_object_op co,
607611
const std::vector<instruction_ref>& inputs)
608612
{
613+
609614
std::vector<instruction_ref> refs;
615+
std::size_t last = 0;
616+
#ifdef MIGRAPHX_MLIR_BARE_POINTER
617+
refs.reserve(inputs.size());
618+
std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
619+
last = refs.size() - 1;
620+
#else
610621
refs.reserve(inputs.size() * 15);
611-
612622
std::unordered_map<uint64_t, instruction_ref> literal_map{};
613623
auto get_literal = [&](uint64_t value) {
614624
auto fi = literal_map.find(value);
@@ -619,7 +629,6 @@ instruction_ref insert_mlir(module& m,
619629
return lit;
620630
};
621631

622-
std::size_t last = 0;
623632
for(auto input : inputs)
624633
{
625634
const size_t offset = 0;
@@ -643,6 +652,7 @@ instruction_ref insert_mlir(module& m,
643652
[&](const auto& lval) { return get_literal(lval); });
644653
// refs.push_back(get_literal(1)); // G
645654
}
655+
#endif
646656
co.expected_inputs = to_shapes(refs);
647657
co.output_arg = last;
648658
return m.insert_instruction(ins, co, refs);

0 commit comments

Comments
 (0)