@@ -297,21 +297,66 @@ def test_fp8_weight_dimension_warning(self):
    @unittest.skipIf(
        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
    )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "input_shape", [(1, 512), (8, 512), (16, 512), (2, 8, 512), (2, 2, 16, 512)]
+    )
+    def test_mm_float8dq(self, in_features, out_features, input_shape):
        device = "cuda"
        dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
+
+        # Adjust input shape to match in_features
+        input_shape = list(input_shape)
+        input_shape[-1] = in_features
+
+        weight = torch.randn(in_features, out_features).to(device).to(dtype)
        weight = weight.t()

-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=False).to(device).to(dtype)
+        )
+        ref_linear.weight = torch.nn.Parameter(weight.clone())
+
+        test_linear = (
+            torch.nn.Linear(in_features, out_features, bias=False).to(device).to(dtype)
+        )
+        test_linear.weight = torch.nn.Parameter(weight.clone())
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        max_abs_error = (ref_output - quant_output).abs().max().item()
+        ref_max = ref_output.abs().max().item()
+        relative_error = max_abs_error / ref_max if ref_max > 0 else 0

-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        self.assertLess(relative_error, 0.05)


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
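Context for the `(out_features, 1)` scale-shape assertion above: PerRow granularity assigns one scale to each output row of the `(out_features, in_features)` weight, so the scale tensor keeps a trailing unit dimension. Below is a minimal sketch of that shape logic only, not torchao's implementation; the `amax / fp8_max` scale convention is an assumption.

import torch

def per_row_float8_scale(weight: torch.Tensor) -> torch.Tensor:
    # One absolute max per output row; keepdim=True preserves the trailing dim
    amax = weight.abs().amax(dim=-1, keepdim=True)  # (out_features, 1)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # Assumed convention: scale maps the row's amax onto the fp8 dynamic range
    return amax.float() / fp8_max  # (out_features, 1)

w = torch.randn(1024, 512)  # (out_features, in_features)
assert per_row_float8_scale(w).shape == (1024, 1)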