diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index 94d230e8f82..b51c2567f0a 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -114,44 +114,106 @@ Error register_kernels(const Span kernels) { } namespace { -int copy_char_as_number_to_buf(char num, char* buf) { - if ((char)num < 10) { +/** + * Writes `num` as a decimal string to `buf` and returns the number of bytes + * written. Returns -1 if `buf` is too small or if `num` is not supported. + */ +int copy_char_as_number_to_buf(int num, char* buf, size_t buf_size) { + if (num < 0) { + return -1; + } + if (num < 10) { + if (buf_size < 1) { + return -1; + } *buf = '0' + (char)num; - buf += 1; return 1; - } else { - *buf = '0' + ((char)num) / 10; - buf += 1; + } + if (num < 100) { + if (buf_size < 2) { + return -1; + } + *buf++ = '0' + ((char)num) / 10; *buf = '0' + ((char)num) % 10; - buf += 1; return 2; } + return -1; } } // namespace namespace internal { -void make_kernel_key_string(Span key, char* buf) { +Error make_kernel_key_string( + Span key, + char* buf, + size_t buf_size) { if (key.empty()) { - // If no tensor is present in an op, kernel key does not apply - return; + // If no tensor is present in an op, kernel key does not apply. + if (buf_size > 0) { + buf[0] = '\0'; + } + return Error::Ok; } - strncpy(buf, "v1/", 3); + + // Reserve one byte for null terminator. + if (buf_size < 1) { + return Error::InvalidArgument; + } + buf_size -= 1; + + // Add prefix. + if (buf_size < 3) { + return Error::InvalidArgument; + } + memcpy(buf, "v1/", 3); buf += 3; + buf_size -= 3; + + // Add tensor meta. for (size_t i = 0; i < key.size(); i++) { auto& meta = key[i]; - buf += copy_char_as_number_to_buf((char)meta.dtype_, buf); - *buf = ';'; - buf += 1; + + // Add dtype. + int n = copy_char_as_number_to_buf((int)meta.dtype_, buf, buf_size); + if (n < 0) { + return Error::InvalidArgument; + } + buf += n; + buf_size -= n; + + // Add separator between dtype and dim order. + if (buf_size < 1) { + return Error::InvalidArgument; + } + *buf++ = ';'; + buf_size -= 1; + + // Add dim order. for (int j = 0; j < meta.dim_order_.size(); j++) { - buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf); - if (j != meta.dim_order_.size() - 1) { - *buf = ','; - buf += 1; + n = copy_char_as_number_to_buf((int)meta.dim_order_[j], buf, buf_size); + if (n < 0) { + return Error::InvalidArgument; + } + buf += n; + buf_size -= n; + + if (j < meta.dim_order_.size() - 1) { + if (buf_size < 1) { + return Error::InvalidArgument; + } + *buf++ = ','; + buf_size -= 1; + } + } + if (i < key.size() - 1) { + if (buf_size < 1) { + return Error::InvalidArgument; } + *buf++ = '|'; + buf_size -= 1; } - *buf = (i < (key.size() - 1)) ? '|' : 0x00; - buf += 1; } + *buf = '\0'; // Space for this was reserved above. + return Error::Ok; } } // namespace internal @@ -164,10 +226,14 @@ bool registry_has_op_function( Result get_op_function_from_registry( const char* name, Span meta_list) { - // @lint-ignore CLANGTIDY facebook-hte-CArray - char buf[KernelKey::MAX_SIZE] = {0}; - internal::make_kernel_key_string(meta_list, buf); - KernelKey kernel_key = KernelKey(buf); + std::array key_string; + Error err = internal::make_kernel_key_string( + meta_list, key_string.data(), key_string.size()); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to make kernel key string"); + return err; + } + KernelKey kernel_key = KernelKey(key_string.data()); int32_t fallback_idx = -1; for (size_t idx = 0; idx < num_registered_kernels; idx++) { diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index e4c5d6706e2..82815852e6f 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -96,25 +96,21 @@ struct TensorMeta { /** * Describes which dtype & dim order specialized kernel to be bound to an - * operator. If `is_fallback_` is true, it means this kernel can be used as a - * fallback, if false, it means this kernel can only be used if all the - * `TensorMeta` are matched. Fallback means this kernel will be used for - * all input tensor dtypes and dim orders, if the specialized kernel is not - * registered. + * operator. * - * The format of a kernel key data is a string: - * "v/|..." - * Size: Up to 691 1 1 1 (42 +1) * 16 - * Assuming max number of tensors is 16 ^ - * Kernel key version is v1 for now. If the kernel key format changes, - * update the version to avoid breaking pre-existing kernel keys. - * Example: v1/7;0,1,2,3 - * The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3 + * Kernel key data is a string with the format: + * + * "v/|..." + * + * The version is v1 for now. If the kernel key format changes, update the + * version to avoid breaking pre-existing kernel keys. * * Each tensor_meta has the following format: ";" - * Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2 - * for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example: - * 7;0,1,2,3 for [double; 0, 1, 2, 3] + * + * Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7" + * + * This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the + * second with dtype=1 and dim order 0,1,2,3,4,5,6,7. * * IMPORTANT: * Users should not construct a kernel key manually. Instead, it should be @@ -122,13 +118,21 @@ struct TensorMeta { */ struct KernelKey { public: + /** + * Creates a fallback (non-specialized) kernel key: this kernel can be used + * for all input tensor dtypes and dim orders if the specialized kernel is not + * registered. + */ KernelKey() : is_fallback_(true) {} + /** + * Creates a specialized (non-fallback) kernel key that matches a specific + * set of input tensor dtypes and dim orders. See the class comment for the + * expected format of `kernel_key_data`. + */ /* implicit */ KernelKey(const char* kernel_key_data) : kernel_key_data_(kernel_key_data), is_fallback_(false) {} - constexpr static int MAX_SIZE = 691; - bool operator==(const KernelKey& other) const { return this->equals(other); } @@ -144,7 +148,7 @@ struct KernelKey { if (is_fallback_) { return true; } - return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0; + return strcmp(kernel_key_data_, other.kernel_key_data_) == 0; } bool is_fallback() const { @@ -194,7 +198,23 @@ struct Kernel { }; namespace internal { -void make_kernel_key_string(Span key, char* buf); + +/** + * A make_kernel_key_string buffer size that is large enough to hold a kernel + * key string with 16 tensors of 16 dimensions, plus the trailing NUL byte. + */ +constexpr size_t kKernelKeyBufSize = 659; + +/** + * Given the list of input tensor dtypes + dim orders, writes the kernel key + * string into the buffer. Returns an error if the buffer is too small or if the + * tensors cannot be represented as a valid key string. + */ +Error make_kernel_key_string( + Span key, + char* buf, + size_t buf_size); + } // namespace internal /** diff --git a/runtime/kernel/test/operator_registry_test.cpp b/runtime/kernel/test/operator_registry_test.cpp index 15104609b92..76c2e8e0930 100644 --- a/runtime/kernel/test/operator_registry_test.cpp +++ b/runtime/kernel/test/operator_registry_test.cpp @@ -34,8 +34,147 @@ using executorch::runtime::registry_has_op_function; using executorch::runtime::Result; using executorch::runtime::Span; using executorch::runtime::TensorMeta; +using executorch::runtime::internal::kKernelKeyBufSize; using executorch::runtime::testing::make_kernel_key; +// +// Tests for make_kernel_key_string +// + +// Helper for testing make_kernel_key_string. +void test_make_kernel_key_string( + const std::vector>>& tensors, + const char* expected_key) { + const size_t min_buf_size = strlen(expected_key) + 1; + + // Sweep across too-small buffer sizes, exercising all possible failure + // checks. Rely on ASAN to detect buffer overflows. + for (size_t buf_size = 0; buf_size < min_buf_size; buf_size++) { + std::vector actual_key(buf_size, 0x55); + Error err = make_kernel_key( + tensors, + // nullptr should be valid for buf_size == 0 because it won't be written + // to. + buf_size == 0 ? nullptr : actual_key.data(), + actual_key.size()); + EXPECT_NE(err, Error::Ok); + } + + // Demonstrate that it succeeds for buffers of exactly the right size or + // larger. + for (size_t buf_size = min_buf_size; buf_size < min_buf_size + 1; + buf_size++) { + std::vector actual_key(buf_size, 0x55); + Error err = make_kernel_key(tensors, actual_key.data(), actual_key.size()); + ASSERT_EQ(err, Error::Ok); + EXPECT_STREQ(actual_key.data(), expected_key); + } +} + +TEST(MakeKernelKeyStringTest, ZeroTensorSuccessWithNullBuffer) { + Error err = make_kernel_key({}, nullptr, 0); + EXPECT_EQ(err, Error::Ok); +} + +TEST(MakeKernelKeyStringTest, ZeroTensorSuccessMakesEmptyString) { + char buf = 0x55; + Error err = make_kernel_key({}, &buf, 1); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(buf, '\0'); +} + +TEST(MakeKernelKeyStringTest, OneTensorSuccess) { + test_make_kernel_key_string( + {{ScalarType::Long, {0, 1, 2, 3}}}, "v1/4;0,1,2,3"); +} + +TEST(MakeKernelKeyStringTest, TwoTensorSuccess) { + test_make_kernel_key_string( + {{ScalarType::Long, {0, 1, 2, 3}}, {ScalarType::Double, {3, 2, 1, 0}}}, + "v1/4;0,1,2,3|7;3,2,1,0"); +} + +TEST(MakeKernelKeyStringTest, ThreeTensorSuccess) { + test_make_kernel_key_string( + {{ScalarType::Long, {0, 1, 2, 3}}, + {ScalarType::Double, {3, 2, 1, 0}}, + {ScalarType::Byte, {2, 1, 3, 0}}}, + "v1/4;0,1,2,3|7;3,2,1,0|0;2,1,3,0"); +} + +TEST(MakeKernelKeyStringTest, TwoDigitDimOrderSuccess) { + test_make_kernel_key_string( + {{ScalarType::Long, {0, 10, 2, 99}}}, "v1/4;0,10,2,99"); +} + +TEST(MakeKernelKeyStringTest, ThreeDigitDimOrderFailure) { + std::vector actual_key(1024, 0x55); // Large enough for any key. + Error err = make_kernel_key( + // Cannot represent a dim order entry with more than two digits. + {{ScalarType::Long, {0, 100, 2, 255}}}, + actual_key.data(), + actual_key.size()); + EXPECT_NE(err, Error::Ok); +} + +TEST(MakeKernelKeyStringTest, NegativeScalarTypeFailure) { + std::vector actual_key(1024, 0x55); // Large enough for any key. + Error err = make_kernel_key( + // Cannot represent a ScalarType (aka int8_t) with a negative value. + {{(ScalarType)-1, {0, 1, 2, 3}}}, + actual_key.data(), + actual_key.size()); + EXPECT_NE(err, Error::Ok); +} + +TEST(MakeKernelKeyStringTest, KeyBufSizeMeetsAssumptions) { + // Create the longest key that fits in the assupmtions of kKernelKeyBufSize: + // 16 tensors, 16 dims, with two-digit ScalarTypes. + std::vector>> + tensors; + tensors.reserve(16); + for (int i = 0; i < 16; i++) { + std::vector dims; + dims.reserve(16); + for (int j = 0; j < 16; j++) { + dims.emplace_back(j); + } + tensors.emplace_back((ScalarType)10, dims); + } + + std::vector actual_key(kKernelKeyBufSize, 0x55); + Error err = make_kernel_key(tensors, actual_key.data(), actual_key.size()); + ASSERT_EQ(err, Error::Ok); + EXPECT_STREQ( + actual_key.data(), + "v1/" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15|" + "10;0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15"); + EXPECT_LE(strlen(actual_key.data()) + 1, kKernelKeyBufSize); +} + +// +// Tests for public operator registry APIs +// + class OperatorRegistryTest : public ::testing::Test { public: void SetUp() override { @@ -46,7 +185,8 @@ class OperatorRegistryTest : public ::testing::Test { TEST_F(OperatorRegistryTest, Basic) { Kernel kernels[] = {Kernel("foo", [](KernelRuntimeContext&, EValue**) {})}; Span kernels_span(kernels); - (void)register_kernels(kernels_span); + Error err = register_kernels(kernels_span); + ASSERT_EQ(err, Error::Ok); EXPECT_FALSE(registry_has_op_function("fpp")); EXPECT_TRUE(registry_has_op_function("foo")); } @@ -59,12 +199,14 @@ TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) { ET_EXPECT_DEATH({ (void)register_kernels(kernels_span); }, ""); } -constexpr int BUF_SIZE = KernelKey::MAX_SIZE; - TEST_F(OperatorRegistryTest, KernelKeyEquals) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey long_contiguous = KernelKey(buf_long_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey long_contiguous = KernelKey(buf_long_contiguous.data()); KernelKey long_key_1 = KernelKey(long_contiguous); @@ -72,31 +214,73 @@ TEST_F(OperatorRegistryTest, KernelKeyEquals) { EXPECT_EQ(long_key_1, long_key_2); - char buf_float_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, buf_float_contiguous); - KernelKey float_key = KernelKey(buf_float_contiguous); + std::array buf_float_contiguous; + err = make_kernel_key( + {{ScalarType::Float, {0, 1, 2, 3}}}, + buf_float_contiguous.data(), + buf_float_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey float_key = KernelKey(buf_float_contiguous.data()); EXPECT_NE(long_key_1, float_key); - char buf_channel_first[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 3, 1, 2}}}, buf_channel_first); - KernelKey long_key_3 = KernelKey(buf_channel_first); + std::array buf_channel_first; + err = make_kernel_key( + {{ScalarType::Long, {0, 3, 1, 2}}}, + buf_channel_first.data(), + buf_channel_first.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey long_key_3 = KernelKey(buf_channel_first.data()); EXPECT_NE(long_key_1, long_key_3); } +TEST_F(OperatorRegistryTest, GetOpFailsForLongKernelKey) { + // Looking up a way-too-long kernel key should fail with an error. + std::vector>> + tensors; + // 1000 is a lot of tensors. + tensors.reserve(1000); + for (int i = 0; i < 1000; i++) { + std::vector dims; + dims.reserve(16); + for (int j = 0; j < 16; j++) { + dims.emplace_back(j); + } + tensors.emplace_back((ScalarType)10, dims); + } + std::vector meta; + for (auto& t : tensors) { + Span dim_order( + t.second.data(), t.second.size()); + meta.emplace_back(t.first, dim_order); + } + Span metadata(meta.data(), meta.size()); + + auto op = get_op_function_from_registry("test::not-real", metadata); + EXPECT_NE(op.error(), Error::Ok); + EXPECT_NE(op.error(), Error::OperatorMissing); + // The lookup failed, but not because the operator is missing. +} + TEST_F(OperatorRegistryTest, RegisterKernels) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey key = KernelKey(buf_long_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key = KernelKey(buf_long_contiguous.data()); Kernel kernel_1 = Kernel( "test::boo", key, [](KernelRuntimeContext& context, EValue** stack) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({&kernel_1, 1}); - EXPECT_EQ(s1, Error::Ok); + err = register_kernels({&kernel_1, 1}); + ASSERT_EQ(err, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; auto dim_order_type = Span(dims, 4); @@ -126,13 +310,21 @@ TEST_F(OperatorRegistryTest, RegisterKernels) { } TEST_F(OperatorRegistryTest, RegisterTwoKernels) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey key_1 = KernelKey(buf_long_contiguous); - - char buf_float_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, buf_float_contiguous); - KernelKey key_2 = KernelKey(buf_float_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key_1 = KernelKey(buf_long_contiguous.data()); + + std::array buf_float_contiguous; + err = make_kernel_key( + {{ScalarType::Float, {0, 1, 2, 3}}}, + buf_float_contiguous.data(), + buf_float_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key_2 = KernelKey(buf_float_contiguous.data()); Kernel kernel_1 = Kernel( "test::bar", key_1, [](KernelRuntimeContext& context, EValue** stack) { (void)context; @@ -144,7 +336,9 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { *(stack[0]) = Scalar(50); }); Kernel kernels[] = {kernel_1, kernel_2}; - auto s1 = register_kernels(kernels); + err = register_kernels(kernels); + ASSERT_EQ(err, Error::Ok); + // has both kernels Tensor::DimOrderType dims[] = {0, 1, 2, 3}; auto dim_order_type = Span(dims, 4); @@ -189,9 +383,13 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { } TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey key = KernelKey(buf_long_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key = KernelKey(buf_long_contiguous.data()); Kernel kernel_1 = Kernel( "test::baz", key, [](KernelRuntimeContext& context, EValue** stack) { @@ -205,22 +403,26 @@ TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) { }); Kernel kernels[] = {kernel_1, kernel_2}; // clang-tidy off - ET_EXPECT_DEATH({ auto s1 = register_kernels(kernels); }, ""); + ET_EXPECT_DEATH({ (void)register_kernels(kernels); }, ""); // clang-tidy on } TEST_F(OperatorRegistryTest, ExecutorChecksKernel) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey key = KernelKey(buf_long_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key = KernelKey(buf_long_contiguous.data()); Kernel kernel_1 = Kernel( "test::qux", key, [](KernelRuntimeContext& context, EValue** stack) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({&kernel_1, 1}); - EXPECT_EQ(s1, Error::Ok); + err = register_kernels({&kernel_1, 1}); + ASSERT_EQ(err, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; auto dim_order_type = Span(dims, 4); @@ -242,17 +444,21 @@ TEST_F(OperatorRegistryTest, ExecutorChecksKernel) { } TEST_F(OperatorRegistryTest, ExecutorUsesKernel) { - char buf_long_contiguous[BUF_SIZE]; - make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous); - KernelKey key = KernelKey(buf_long_contiguous); + std::array buf_long_contiguous; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, + buf_long_contiguous.data(), + buf_long_contiguous.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey key = KernelKey(buf_long_contiguous.data()); Kernel kernel_1 = Kernel( "test::quux", key, [](KernelRuntimeContext& context, EValue** stack) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({&kernel_1, 1}); - EXPECT_EQ(s1, Error::Ok); + err = register_kernels({&kernel_1, 1}); + ASSERT_EQ(err, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; auto dim_order_type = Span(dims, 4); @@ -283,8 +489,8 @@ TEST_F(OperatorRegistryTest, ExecutorUsesFallbackKernel) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({&kernel_1, 1}); - EXPECT_EQ(s1, Error::Ok); + Error err = register_kernels({&kernel_1, 1}); + EXPECT_EQ(err, Error::Ok); EXPECT_TRUE(registry_has_op_function("test::corge")); EXPECT_TRUE(registry_has_op_function("test::corge", {})); diff --git a/runtime/kernel/test/test_util.h b/runtime/kernel/test/test_util.h index 082635bd0e4..be77df1fd0c 100644 --- a/runtime/kernel/test/test_util.h +++ b/runtime/kernel/test/test_util.h @@ -18,19 +18,20 @@ namespace runtime { namespace testing { -inline void make_kernel_key( - std::vector>> tensors, - char* buf) { + std::vector>>& tensors, + char* buf, + size_t buf_size) { std::vector meta; for (auto& t : tensors) { Span dim_order( - t.second.data(), t.second.size()); + const_cast(t.second.data()), t.second.size()); meta.emplace_back(t.first, dim_order); } Span metadata(meta.data(), meta.size()); - internal::make_kernel_key_string(metadata, buf); + return internal::make_kernel_key_string(metadata, buf, buf_size); } } // namespace testing