Skip to content

Commit b6fbd2c

Browse files
dbortfacebook-github-bot
authored andcommitted
Add safety checks when rendering kernel key strings
Summary: The old code assumed that it was handed a MAX_SIZE buffer, and that the list of TensorMeta values would never generate a string longer than that size. This PR adds explicit size tracking and an error code to the API, and now returns an error if the buffer is too small for the provided values. While I'm here, move MAX_SIZE out of the public API, since it's not an intrinsic aspect of kernel keys. This is technically a BC-breaking change, but I don't expect that any users are actually depending on it. Add unit tests for all modified code. Differential Revision: D69324821
1 parent d99970b commit b6fbd2c

File tree

4 files changed

+356
-68
lines changed

4 files changed

+356
-68
lines changed

runtime/kernel/operator_registry.cpp

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -114,44 +114,106 @@ Error register_kernels(const Span<const Kernel> kernels) {
114114
}
115115

116116
namespace {
117-
int copy_char_as_number_to_buf(char num, char* buf) {
118-
if ((char)num < 10) {
117+
/**
118+
* Writes `num` as a decimal string to `buf` and returns the number of bytes
119+
* written. Returns -1 if `buf` is too small or if `num` is not supported.
120+
*/
121+
int copy_char_as_number_to_buf(int num, char* buf, size_t buf_size) {
122+
if (num < 0) {
123+
return -1;
124+
}
125+
if (num < 10) {
126+
if (buf_size < 1) {
127+
return -1;
128+
}
119129
*buf = '0' + (char)num;
120-
buf += 1;
121130
return 1;
122-
} else {
123-
*buf = '0' + ((char)num) / 10;
124-
buf += 1;
131+
}
132+
if (num < 100) {
133+
if (buf_size < 2) {
134+
return -1;
135+
}
136+
*buf++ = '0' + ((char)num) / 10;
125137
*buf = '0' + ((char)num) % 10;
126-
buf += 1;
127138
return 2;
128139
}
140+
return -1;
129141
}
130142
} // namespace
131143

132144
namespace internal {
133-
void make_kernel_key_string(Span<const TensorMeta> key, char* buf) {
145+
Error make_kernel_key_string(
146+
Span<const TensorMeta> key,
147+
char* buf,
148+
size_t buf_size) {
134149
if (key.empty()) {
135-
// If no tensor is present in an op, kernel key does not apply
136-
return;
150+
// If no tensor is present in an op, kernel key does not apply.
151+
if (buf_size > 0) {
152+
buf[0] = '\0';
153+
}
154+
return Error::Ok;
137155
}
138-
strncpy(buf, "v1/", 3);
156+
157+
// Reserve one byte for null terminator.
158+
if (buf_size < 1) {
159+
return Error::InvalidArgument;
160+
}
161+
buf_size -= 1;
162+
163+
// Add prefix.
164+
if (buf_size < 3) {
165+
return Error::InvalidArgument;
166+
}
167+
memcpy(buf, "v1/", 3);
139168
buf += 3;
169+
buf_size -= 3;
170+
171+
// Add tensor meta.
140172
for (size_t i = 0; i < key.size(); i++) {
141173
auto& meta = key[i];
142-
buf += copy_char_as_number_to_buf((char)meta.dtype_, buf);
143-
*buf = ';';
144-
buf += 1;
174+
175+
// Add dtype.
176+
int n = copy_char_as_number_to_buf((int)meta.dtype_, buf, buf_size);
177+
if (n < 0) {
178+
return Error::InvalidArgument;
179+
}
180+
buf += n;
181+
buf_size -= n;
182+
183+
// Add separator between dtype and dim order.
184+
if (buf_size < 1) {
185+
return Error::InvalidArgument;
186+
}
187+
*buf++ = ';';
188+
buf_size -= 1;
189+
190+
// Add dim order.
145191
for (int j = 0; j < meta.dim_order_.size(); j++) {
146-
buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf);
147-
if (j != meta.dim_order_.size() - 1) {
148-
*buf = ',';
149-
buf += 1;
192+
n = copy_char_as_number_to_buf((int)meta.dim_order_[j], buf, buf_size);
193+
if (n < 0) {
194+
return Error::InvalidArgument;
195+
}
196+
buf += n;
197+
buf_size -= n;
198+
199+
if (j < meta.dim_order_.size() - 1) {
200+
if (buf_size < 1) {
201+
return Error::InvalidArgument;
202+
}
203+
*buf++ = ',';
204+
buf_size -= 1;
205+
}
206+
}
207+
if (i < key.size() - 1) {
208+
if (buf_size < 1) {
209+
return Error::InvalidArgument;
150210
}
211+
*buf++ = '|';
212+
buf_size -= 1;
151213
}
152-
*buf = (i < (key.size() - 1)) ? '|' : 0x00;
153-
buf += 1;
154214
}
215+
*buf = '\0'; // Space for this was reserved above.
216+
return Error::Ok;
155217
}
156218
} // namespace internal
157219

@@ -165,8 +227,12 @@ Result<OpFunction> get_op_function_from_registry(
165227
const char* name,
166228
Span<const TensorMeta> meta_list) {
167229
// @lint-ignore CLANGTIDY facebook-hte-CArray
168-
char buf[KernelKey::MAX_SIZE] = {0};
169-
internal::make_kernel_key_string(meta_list, buf);
230+
char buf[internal::kKernelKeyBufSize];
231+
Error err = internal::make_kernel_key_string(meta_list, buf, sizeof(buf));
232+
if (err != Error::Ok) {
233+
ET_LOG(Error, "Failed to make kernel key string");
234+
return err;
235+
}
170236
KernelKey kernel_key = KernelKey(buf);
171237

172238
int32_t fallback_idx = -1;

runtime/kernel/operator_registry.h

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,39 +96,43 @@ struct TensorMeta {
9696

9797
/**
9898
* Describes which dtype & dim order specialized kernel to be bound to an
99-
* operator. If `is_fallback_` is true, it means this kernel can be used as a
100-
* fallback, if false, it means this kernel can only be used if all the
101-
* `TensorMeta` are matched. Fallback means this kernel will be used for
102-
* all input tensor dtypes and dim orders, if the specialized kernel is not
103-
* registered.
99+
* operator.
104100
*
105-
* The format of a kernel key data is a string:
106-
* "v<version>/<tensor_meta>|<tensor_meta>..."
107-
* Size: Up to 691 1 1 1 (42 +1) * 16
108-
* Assuming max number of tensors is 16 ^
109-
* Kernel key version is v1 for now. If the kernel key format changes,
110-
* update the version to avoid breaking pre-existing kernel keys.
111-
* Example: v1/7;0,1,2,3
112-
* The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
101+
* Kernel key data is a string with the format:
102+
*
103+
* "v<version>/<tensor_meta>|<tensor_meta>..."
104+
*
105+
* The version is v1 for now. If the kernel key format changes, update the
106+
* version to avoid breaking pre-existing kernel keys.
113107
*
114108
* Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
115-
* Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
116-
* for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
117-
* 7;0,1,2,3 for [double; 0, 1, 2, 3]
109+
*
110+
* Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7"
111+
*
112+
* This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the
113+
* second with dtype=1 and dim order 0,1,2,3,4,5,6,7.
118114
*
119115
* IMPORTANT:
120116
* Users should not construct a kernel key manually. Instead, it should be
121117
* generated from kernel yaml.
122118
*/
123119
struct KernelKey {
124120
public:
121+
/**
122+
* Creates a fallback (non-specialized) kernel key: this kernel can be used
123+
* for all input tensor dtypes and dim orders if the specialized kernel is not
124+
* registered.
125+
*/
125126
KernelKey() : is_fallback_(true) {}
126127

128+
/**
129+
* Creates a specialized (non-fallback) kernel key that matches a specific
130+
* set of input tensor dtypes and dim orders. See the class comment for the
131+
* expected format of `kernel_key_data`.
132+
*/
127133
/* implicit */ KernelKey(const char* kernel_key_data)
128134
: kernel_key_data_(kernel_key_data), is_fallback_(false) {}
129135

130-
constexpr static int MAX_SIZE = 691;
131-
132136
bool operator==(const KernelKey& other) const {
133137
return this->equals(other);
134138
}
@@ -144,7 +148,7 @@ struct KernelKey {
144148
if (is_fallback_) {
145149
return true;
146150
}
147-
return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
151+
return strcmp(kernel_key_data_, other.kernel_key_data_) == 0;
148152
}
149153

150154
bool is_fallback() const {
@@ -194,7 +198,23 @@ struct Kernel {
194198
};
195199

196200
namespace internal {
197-
void make_kernel_key_string(Span<const TensorMeta> key, char* buf);
201+
202+
/**
203+
* A make_kernel_key_string buffer size that is large enough to hold a kernel
204+
* key string with 16 tensors of 16 dimensions, plus the trailing NUL byte.
205+
*/
206+
constexpr size_t kKernelKeyBufSize = 659;
207+
208+
/**
209+
* Given the list of input tensor dtypes + dim orders, writes the kernel key
210+
* string into the buffer. Returns an error if the buffer is too small or if the
211+
* tensors cannot be represented as a valid key string.
212+
*/
213+
Error make_kernel_key_string(
214+
Span<const TensorMeta> key,
215+
char* buf,
216+
size_t buf_size);
217+
198218
} // namespace internal
199219

200220
/**

0 commit comments

Comments
 (0)