@@ -114,3 +114,277 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
} // namespace native
} // namespace executor
} // namespace torch
+
+ //
+ // Reference Implementation
+ //
+
+ /*
+  * Reference implementation of choose_qparams_tensor
+  */
+ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
+     const at::Tensor& input,
+     int64_t quant_min,
+     int64_t quant_max) {
+   // Create output tensors
+   at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
+   at::Tensor zero_point_out =
+       at::empty({}, at::device(at::kCPU).dtype(at::kLong));
+
+   // Find min and max values in the input tensor
+   float min_val = input.min().item<float>();
+   float max_val = input.max().item<float>();
+
+   // Extend the [min, max] interval to ensure it contains 0
+   min_val = std::min(min_val, 0.f);
+   max_val = std::max(max_val, 0.f);
+
+   // Calculate scale
+   double scale =
+       (static_cast<double>(max_val) - min_val) / (quant_max - quant_min);
+
+   // Handle small scale
+   constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
+   if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
+     scale = 0.1;
+   }
+
+   if (scale < SMALL_SCALE_THRESHOLD) {
+     float org_scale = scale;
+     scale = SMALL_SCALE_THRESHOLD;
+     // Adjust min and max based on new scale
+     if (min_val == 0.0f) {
+       max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
+     } else if (max_val == 0.0f) {
+       min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
+     } else {
+       float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+       min_val *= amplifier;
+       max_val *= amplifier;
+     }
+   }
+
+   // Calculate zero point
+   double zero_point_from_min = quant_min - min_val / static_cast<double>(scale);
+   double zero_point_from_max = quant_max - max_val / static_cast<double>(scale);
+   double zero_point_from_min_error =
+       std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale));
+   double zero_point_from_max_error =
+       std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale));
+   double initial_zero_point =
+       zero_point_from_min_error < zero_point_from_max_error
+       ? zero_point_from_min
+       : zero_point_from_max;
+
+   // Nudge zero point to be an integer
+   int64_t nudged_zero_point = 0;
+   if (initial_zero_point < quant_min) {
+     nudged_zero_point = quant_min;
+   } else if (initial_zero_point > quant_max) {
+     nudged_zero_point = quant_max;
+   } else {
+     nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point));
+   }
+
+   // Set output values on the scalar output tensors
+   scale_out.fill_(scale);
+   zero_point_out.fill_(nudged_zero_point);
+
+   return std::make_tuple(scale_out, zero_point_out);
+ }
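+
+ // Worked example of the arithmetic above (illustrative only, not part of the
+ // change): with quant_min = -128, quant_max = 127 and an input spanning
+ // [-2.0f, 6.0f],
+ //   scale      = (6.0 - (-2.0)) / (127 - (-128)) = 8.0 / 255 ≈ 0.03137
+ //   zero_point = nearbyint(-128 - (-2.0 / scale)) = nearbyint(-64.25) = -64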
+
+ // Forward declaration of implementation functions
+ void test_vulkan_choose_qparams_tensor_impl(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype,
+     const vkcompute::utils::StorageType in_storage,
+     const vkcompute::utils::StorageType out_storage);
+
+ // Wrapper function to test both buffer and texture storage types
+ void test_vulkan_choose_qparams_tensor(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype) {
+   // Test with buffer storage
+   test_vulkan_choose_qparams_tensor_impl(
+       input_sizes,
+       quant_min,
+       quant_max,
+       dtype,
+       vkcompute::utils::kBuffer,
+       vkcompute::utils::kBuffer);
+
+   // Test with texture storage
+   test_vulkan_choose_qparams_tensor_impl(
+       input_sizes,
+       quant_min,
+       quant_max,
+       dtype,
+       vkcompute::utils::kTexture3D,
+       vkcompute::utils::kTexture3D);
+ }
+
+ void test_reference_choose_qparams_tensor(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype) {
+   std::vector<int64_t> input_sizes_int64(
+       input_sizes.begin(), input_sizes.end());
+   at::Tensor input =
+       at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
+
+   // Get reference output
+   auto [reference_scale, reference_zero_point] =
+       choose_qparams_tensor_reference_impl(input, quant_min, quant_max);
+
+   // Get implementation output
+   auto [impl_scale, impl_zero_point] =
+       torch::executor::native::choose_qparams_tensor_aten(
+           input, quant_min, quant_max, dtype);
+
+   // Compare outputs
+   const bool scale_correct = at::allclose(reference_scale, impl_scale);
+   const bool zero_point_correct =
+       at::equal(reference_zero_point, impl_zero_point);
+
+   if (!scale_correct || !zero_point_correct) {
+     std::cout << "\n"
+               << "Failed with parameters: " << std::endl;
+     std::cout << "quant_min: " << quant_min << std::endl;
+     std::cout << "quant_max: " << quant_max << std::endl;
+
+     std::cout << "input:" << std::endl;
+     std::cout << input << std::endl;
+     std::cout << "reference scale:" << std::endl;
+     std::cout << reference_scale << std::endl;
+     std::cout << "implementation scale:" << std::endl;
+     std::cout << impl_scale << std::endl;
+     std::cout << "reference zero_point:" << std::endl;
+     std::cout << reference_zero_point << std::endl;
+     std::cout << "implementation zero_point:" << std::endl;
+     std::cout << impl_zero_point << std::endl;
+   }
+
+   ASSERT_TRUE(scale_correct && zero_point_correct);
+ }
+
+ void test_vulkan_choose_qparams_tensor_impl(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype,
+     const vkcompute::utils::StorageType in_storage,
+     const vkcompute::utils::StorageType out_storage) {
+   std::vector<int64_t> input_sizes_int64(
+       input_sizes.begin(), input_sizes.end());
+   at::Tensor input =
+       at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
+
+   // Get reference output
+   auto [reference_scale, reference_zero_point] =
+       torch::executor::native::choose_qparams_tensor_aten(
+           input, quant_min, quant_max, dtype);
+
+   // Build Vulkan choose_qparams_tensor graph
+   using namespace vkcompute;
+
+   GraphConfig config;
+   config.set_storage_type_override(in_storage);
+   ComputeGraph graph(config);
+
+   IOValueRef r_input = graph.add_input_tensor(
+       input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
+
+   const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+   const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+   // Output tensors
+   const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage);
+   const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage);
+
+   VK_GET_OP_FN("choose_qparams.tensor")
+   (graph,
+    {
+        r_input.value,
+        r_quant_min,
+        r_quant_max,
+        r_scale,
+        r_zero_point,
+    });
+
+   ValueRef staging_scale = graph.set_output_tensor(r_scale);
+   ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);
+
+   graph.prepare();
+   graph.encode_prepack();
+   graph.prepack();
+   graph.encode_execute();
+
+   // Run Vulkan choose_qparams_tensor
+   graph.copy_into_staging(
+       r_input.staging, input.const_data_ptr(), input.numel());
+
+   graph.execute();
+
+   // Create output tensors to hold the results - use types that match GPU output
+   at::Tensor vk_scale =
+       at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous();
+   at::Tensor vk_zero_point =
+       at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous();
+
+   // Copy results from GPU to CPU
+   graph.copy_from_staging(
+       staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
+   graph.copy_from_staging(
+       staging_zero_point,
+       vk_zero_point.mutable_data_ptr(),
+       vk_zero_point.numel());
+
+   // Convert reference values to match Vulkan output types for comparison
+   at::Tensor reference_scale_float = reference_scale.to(at::kFloat);
+   at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt);
+
+   // Compare outputs
+   const bool scale_correct = at::allclose(reference_scale_float, vk_scale);
+   const bool zero_point_correct =
+       at::equal(reference_zero_point_int, vk_zero_point);
+
+   if (!scale_correct || !zero_point_correct) {
+     std::cout << "\n"
+               << "Failed with parameters: " << std::endl;
+     std::cout << "quant_min: " << quant_min << std::endl;
+     std::cout << "quant_max: " << quant_max << std::endl;
+     std::cout << "storage type: "
+               << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                           : "texture")
+               << std::endl;
+
+     // Make sure that there aren't a ton of elements in the input tensor
+     // before printing it
+     if (input.numel() < 100) {
+       std::cout << "input:" << std::endl;
+       std::cout << input << "\n" << std::endl;
+       std::cout << "reference scale:" << std::endl;
+       std::cout << reference_scale << std::endl;
+       std::cout << "vulkan scale:" << std::endl;
+       std::cout << vk_scale << "\n" << std::endl;
+       std::cout << "reference zero_point:" << std::endl;
+       std::cout << reference_zero_point << std::endl;
+       std::cout << "vulkan zero_point:" << std::endl;
+       std::cout << vk_zero_point << std::endl;
+     }
+   }
+
+   ASSERT_TRUE(scale_correct && zero_point_correct);
+ }
+
+ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
+   test_reference_choose_qparams_tensor(
+       {2, 3, 4}, // input sizes
+       -128, // quant_min
+       127, // quant_max
+       at::kChar);
+ }
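+
+ // Illustrative sketch, not part of the change above: a hypothetical companion
+ // test that would drive the Vulkan path through the
+ // test_vulkan_choose_qparams_tensor wrapper defined earlier (the test name is
+ // an assumption, not taken from the PR).
+ TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8) {
+   test_vulkan_choose_qparams_tensor(
+       {2, 3, 4}, // input sizes
+       -128, // quant_min
+       127, // quant_max
+       at::kChar);
+ }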