Commit b54acab

[ET-VK][Ops] choose_qparams.tensor test setup
Differential Revision: D76436894
Pull Request resolved: #11555
1 parent f25bd31 commit b54acab

1 file changed: 274 additions, 0 deletions

backends/vulkan/test/op_tests/choose_qparams_test.cpp

Lines changed: 274 additions & 0 deletions
@@ -114,3 +114,277 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
} // namespace native
} // namespace executor
} // namespace torch

//
// Reference Implementation
//

/*
 * Reference implementation of choose_qparams_tensor
 */
std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
    const at::Tensor& input,
    int64_t quant_min,
    int64_t quant_max) {
  // Create output tensors
  at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
  at::Tensor zero_point_out =
      at::empty({}, at::device(at::kCPU).dtype(at::kLong));

  // Find min and max values in the input tensor
  float min_val = input.min().item<float>();
  float max_val = input.max().item<float>();

  // Extend the [min, max] interval to ensure it contains 0
  min_val = std::min(min_val, 0.f);
  max_val = std::max(max_val, 0.f);

  // Calculate scale
  double scale =
      (static_cast<double>(max_val) - min_val) / (quant_max - quant_min);

  // Handle small scale
  constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
  if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
    scale = 0.1;
  }

  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust min and max based on new scale
    if (min_val == 0.0f) {
      max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
    } else if (max_val == 0.0f) {
      min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
    } else {
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min_val *= amplifier;
      max_val *= amplifier;
    }
  }

  // Calculate zero point
  double zero_point_from_min = quant_min - min_val / static_cast<double>(scale);
  double zero_point_from_max = quant_max - max_val / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // Nudge zero point to be an integer
  int64_t nudged_zero_point = 0;
  if (initial_zero_point < quant_min) {
    nudged_zero_point = quant_min;
  } else if (initial_zero_point > quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point));
  }

  // Set output values - fill_() works for zero-dim (scalar) tensors
  scale_out.fill_(scale);
  zero_point_out.fill_(nudged_zero_point);

  return std::make_tuple(scale_out, zero_point_out);
}
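
To make the arithmetic above concrete, here is a minimal standalone sketch (not part of this diff) that walks the same scale/zero-point computation for an int8 range, assuming an observed input range of [-1.0, 2.0]:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  // Observed input range, extended so that it contains 0 (as in the reference impl).
  const float min_val = std::min(-1.0f, 0.0f); // -1.0
  const float max_val = std::max(2.0f, 0.0f); // 2.0

  // scale = (2.0 - (-1.0)) / (127 - (-128)) = 3.0 / 255 ≈ 0.011765
  const double scale =
      (static_cast<double>(max_val) - min_val) / (quant_max - quant_min);

  // Both zero-point candidates agree for this range: -128 - (-85) = -43 and 127 - 170 = -43
  const double zp_from_min = quant_min - min_val / scale;
  const double zp_from_max = quant_max - max_val / scale;
  const int64_t zero_point = std::llrint(zp_from_min);

  std::printf(
      "scale=%.6f zp_min=%.1f zp_max=%.1f zero_point=%lld\n",
      scale,
      zp_from_min,
      zp_from_max,
      static_cast<long long>(zero_point));
  return 0;
}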

// Forward declaration of implementation function
void test_vulkan_choose_qparams_tensor_impl(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);

// Wrapper function to test both buffer and texture storage types
void test_vulkan_choose_qparams_tensor(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype) {
  // Test with buffer storage
  test_vulkan_choose_qparams_tensor_impl(
      input_sizes,
      quant_min,
      quant_max,
      dtype,
      vkcompute::utils::kBuffer,
      vkcompute::utils::kBuffer);

  // Test with texture storage
  test_vulkan_choose_qparams_tensor_impl(
      input_sizes,
      quant_min,
      quant_max,
      dtype,
      vkcompute::utils::kTexture3D,
      vkcompute::utils::kTexture3D);
}
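
For reference, a test case that exercises the Vulkan path through this wrapper could look like the sketch below (hypothetical; this commit only registers the reference test at the end of the file):

// Hypothetical example (not added by this commit): exercising both storage
// types through the wrapper for a symmetric int8 range.
TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8) {
  test_vulkan_choose_qparams_tensor(
      {2, 3, 4}, // input sizes
      -128, // quant_min
      127, // quant_max
      at::kChar);
}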

void test_reference_choose_qparams_tensor(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype) {
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());
  at::Tensor input =
      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));

  // Get reference output
  auto [reference_scale, reference_zero_point] =
      choose_qparams_tensor_reference_impl(input, quant_min, quant_max);

  // Get implementation output
  auto [impl_scale, impl_zero_point] =
      torch::executor::native::choose_qparams_tensor_aten(
          input, quant_min, quant_max, dtype);

  // Compare outputs
  const bool scale_correct = at::allclose(reference_scale, impl_scale);
  const bool zero_point_correct =
      at::equal(reference_zero_point, impl_zero_point);

  if (!scale_correct || !zero_point_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "  quant_min: " << quant_min << std::endl;
    std::cout << "  quant_max: " << quant_max << std::endl;

    std::cout << "input:" << std::endl;
    std::cout << input << std::endl;
    std::cout << "reference scale:" << std::endl;
    std::cout << reference_scale << std::endl;
    std::cout << "implementation scale:" << std::endl;
    std::cout << impl_scale << std::endl;
    std::cout << "reference zero_point:" << std::endl;
    std::cout << reference_zero_point << std::endl;
    std::cout << "implementation zero_point:" << std::endl;
    std::cout << impl_zero_point << std::endl;
  }

  ASSERT_TRUE(scale_correct && zero_point_correct);
}

void test_vulkan_choose_qparams_tensor_impl(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage) {
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());
  at::Tensor input =
      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));

  // Get reference output
  auto [reference_scale, reference_zero_point] =
      torch::executor::native::choose_qparams_tensor_aten(
          input, quant_min, quant_max, dtype);

  // Build Vulkan choose_qparams_tensor graph
  using namespace vkcompute;

  GraphConfig config;
  config.set_storage_type_override(in_storage);
  ComputeGraph graph(config);

  IOValueRef r_input = graph.add_input_tensor(
      input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);

  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);

  // Output tensors
  const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage);
  const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage);

  VK_GET_OP_FN("choose_qparams.tensor")
  (graph,
   {
       r_input.value,
       r_quant_min,
       r_quant_max,
       r_scale,
       r_zero_point,
   });

  ValueRef staging_scale = graph.set_output_tensor(r_scale);
  ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);

  graph.prepare();
  graph.encode_prepack();
  graph.prepack();
  graph.encode_execute();

  // Run Vulkan choose_qparams_tensor
  graph.copy_into_staging(
      r_input.staging, input.const_data_ptr(), input.numel());

  graph.execute();

  // Create output tensors to hold the results - use types that match GPU output
  at::Tensor vk_scale =
      at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous();
  at::Tensor vk_zero_point =
      at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous();

  // Copy results from GPU to CPU
  graph.copy_from_staging(
      staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
  graph.copy_from_staging(
      staging_zero_point,
      vk_zero_point.mutable_data_ptr(),
      vk_zero_point.numel());

  // Convert reference values to match Vulkan output types for comparison
  at::Tensor reference_scale_float = reference_scale.to(at::kFloat);
  at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt);

  // Compare outputs
  const bool scale_correct = at::allclose(reference_scale_float, vk_scale);
  const bool zero_point_correct =
      at::equal(reference_zero_point_int, vk_zero_point);

  if (!scale_correct || !zero_point_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "  quant_min: " << quant_min << std::endl;
    std::cout << "  quant_max: " << quant_max << std::endl;
    std::cout << "  storage type: "
              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
                                                          : "texture")
              << std::endl;

    // make sure that there aren't a ton of elements in the input tensor
    if (input.numel() < 100) {
      std::cout << "input:" << std::endl;
      std::cout << input << "\n" << std::endl;
      std::cout << "reference scale:" << std::endl;
      std::cout << reference_scale << std::endl;
      std::cout << "vulkan scale:" << std::endl;
      std::cout << vk_scale << "\n" << std::endl;
      std::cout << "reference zero_point:" << std::endl;
      std::cout << reference_zero_point << std::endl;
      std::cout << "vulkan zero_point:" << std::endl;
      std::cout << vk_zero_point << std::endl;
    }
  }

  ASSERT_TRUE(scale_correct && zero_point_correct);
}

TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
  test_reference_choose_qparams_tensor(
      {2, 3, 4}, // input sizes
      -128, // quant_min
      127, // quant_max
      at::kChar);
}
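
Building on this setup, one further case a follow-up might add (hypothetical, not part of this commit) is an unsigned 8-bit range through the same reference helper:

TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_uint8) {
  test_reference_choose_qparams_tensor(
      {2, 3, 4}, // input sizes
      0, // quant_min
      255, // quant_max
      at::kByte);
}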
