I had this idea and discussed it briefly with @andrewor14.
Conceptually, the current QAT + FSDP flow looks like this:
- sharded FP32 weight -> all-gather in BF16 -> fake quantize
However, we can do a low-bit all-gather instead, since the weight can be (really) quantized before the all-gather (a toy sketch of both flows follows below):
- sharded FP32 weight -> (real) quantize -> all-gather in low-bit -> dequantize
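To make the data flow concrete, here is a toy single-process sketch of the two flows. `torch.cat` over simulated shards stands in for the FSDP all-gather, and per-tensor symmetric int8 stands in for the low-bit format; all names and choices here are illustrative, not the actual torchao API:

```python
import torch

def fake_quantize(w: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    # Quantize-dequantize in one shot (real QAT adds a straight-through estimator).
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax().clamp(min=1e-12) / qmax
    return (w / scale).round().clamp(-qmax, qmax) * scale

def real_quantize(w: torch.Tensor, n_bits: int = 8):
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax().clamp(min=1e-12) / qmax
    q = (w / scale).round().clamp(-qmax, qmax).to(torch.int8)
    return q, scale

shards = [torch.randn(1024, 1024) for _ in range(4)]  # per-rank FP32 shards

# Current flow: all-gather in BF16 (2 bytes/element on the wire), then fake quantize.
w_current = fake_quantize(torch.cat([s.bfloat16() for s in shards]).float())

# Proposed flow: quantize each local shard, all-gather int8 (1 byte/element), dequantize.
q_shards, scales = zip(*(real_quantize(s) for s in shards))
q_full = torch.cat(q_shards)  # this is what would travel over the wire
w_proposed = torch.cat([q.float() * s for q, s in zip(q_shards, scales)])
```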
In terms of perf, we are basically comparing (ignoring potential fusion around this):
- BF16 all-gather + fake quantize
- (Real) quantize (only the local 1/NGPU shard) + low-bit all-gather + dequantize
This might be a small perf win, especially when distributed comm is the bottleneck. It might also be useful for QAT recipes in torchtune.
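For a rough sense of the communication saving, here is my back-of-envelope on wire bytes per all-gather (ignoring scales/zero-points, which are tiny, and the quantize/dequant compute):

```python
# Bytes all-gathered for a single 4096 x 4096 weight:
N = 4096 * 4096
print(N * 2 / 2**20)   # BF16: 2 bytes/elem        -> 32.0 MiB
print(N * 1 / 2**20)   # int8: 1 byte/elem         -> 16.0 MiB (2x less traffic)
print(N // 2 / 2**20)  # int4 packed: 0.5 bytes/elem -> 8.0 MiB (4x less traffic)
```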
This is probably low priority, so I'm just leaving it here in case anyone is interested in implementing it. We'd need to quantify the speedup, if any.
In terms of implementation, we can follow the float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py); a rough sketch follows after the bullets below.
- A tensor subclass to hold the original weight and use the FSDP2 all-gather extension: possibly extend this https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/quantization/qat/affine_fake_quantized_tensor.py
- Another tensor subclass to hold the quantized weight. If AQT has basic support for backward, maybe we can use AQT directly; otherwise, we would need another subclass.
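A minimal sketch of what the first subclass could look like, modeled on `WeightWithDynamicFloat8CastTensor` in the linked `fsdp_utils.py`. Everything here is hypothetical (`Int8QATWeightTensor`, `quantize_int8_symmetric`, per-tensor int8 as the low-bit format), and the hook signatures follow the pinned float8 code, so they may differ on newer torch versions:

```python
import torch
from torch.utils._pytree import tree_map

def quantize_int8_symmetric(w: torch.Tensor):
    # Hypothetical helper: per-tensor symmetric int8 quantization.
    scale = w.abs().amax().clamp(min=1e-12) / 127
    q = (w / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

class Int8QATWeightTensor(torch.Tensor):
    """Holds the original high-precision (sharded) weight and tells FSDP2 to
    all-gather an int8 payload instead of a BF16 copy."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, weight.shape, dtype=weight.dtype, device=weight.device,
            requires_grad=weight.requires_grad,
        )

    def __init__(self, weight: torch.Tensor):
        self._weight = weight

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Simplified: unwrap to the inner weight for all other ops. The float8
        # version re-wraps outputs so the param stays a subclass through FSDP's ops.
        unwrap = lambda t: t._weight if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

    def fsdp_pre_all_gather(self, mesh):
        # (Real) quantize the local shard before communication.
        # NOTE: assumes the scale is consistent across ranks (float8 achieves this
        # with an all-reduced amax); otherwise per-shard scales must be gathered too.
        q, scale = quantize_int8_symmetric(self._weight)
        return (q,), (scale,)  # int8 data is all-gathered; scale rides along as metadata

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (q,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            # FSDP reuses a preallocated output buffer; a real impl would write into it.
            return
        # Dequantize for the forward pass (or wrap in a quantized-weight subclass,
        # e.g. AQT if it gains backward support, to keep the weight in low bits).
        return q.to(param_dtype) * scale, (q,)
```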