
[QAT] Low-bit FSDP all-gather for QAT #1224

Open
@gau-nernst

Description

I had this idea and discussed it briefly with @andrewor14.

Conceptually, the current QAT + FSDP flow looks like this:

  • sharded FP32 weight -> all-gather in BF16 -> fake quantize

However, we can do a low-bit all-gather instead, since the weight can be (real-)quantized before the all-gather:

  • sharded FP32 weight -> (real) quantize -> all-gather in low-bit -> dequantize
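The two paths are numerically equivalent: fake quantize is exactly real quantize followed by dequantize. A minimal int8 sketch (the symmetric round-and-clamp quantizer here is illustrative, not torchao's actual QAT quantizer):

```python
import torch

def fake_quantize(w, scale):
    # Simulated quantization used by QAT today: quantize + dequantize
    # in one step, staying in the floating-point dtype.
    return torch.clamp(torch.round(w / scale), -128, 127) * scale

def real_quantize(w, scale):
    # Actual int8 quantization: this is what would be all-gathered.
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)

def dequantize(q, scale):
    return q.to(torch.float32) * scale

w = torch.randn(4, 8)
scale = (w.abs().max() / 127).item()

# Quantizing before the all-gather and dequantizing after yields exactly
# the values the fake-quantize path would have seen.
assert torch.equal(fake_quantize(w, scale), dequantize(real_quantize(w, scale), scale))
```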

In terms of perf, we are basically comparing (ignoring potential fusion around these ops):

  1. BF16 all-gather + fake quantize
  2. (Real) quantize (on the 1/NGPU shard) + low-bit all-gather + dequantize
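Back-of-envelope on the comm side (illustrative numbers, not a benchmark): the all-gather payload shrinks by the bit-width ratio, at the cost of an extra quantize on the local shard and a dequantize on the full weight.

```python
# Bytes moved by one all-gather of an N-element weight (illustrative size).
N = 4096 * 4096
bf16_bytes = N * 2       # option 1: all-gather in BF16
int8_bytes = N * 1       # option 2: all-gather in int8
int4_bytes = N // 2      # option 2 with 4-bit packing (two values per byte)

print(bf16_bytes / int8_bytes)  # 2.0 -> int8 halves the traffic
print(bf16_bytes / int4_bytes)  # 4.0 -> int4 quarters it
```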

This might be a small perf win, especially when distributed comm is the bottleneck. It might also be useful for the QAT recipes in torchtune.

This is probably low priority, so I'm leaving it here in case anyone is interested in implementing it. We would need to quantify the speedup, if any.

In terms of implementation, we can follow the float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py).
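For a flavor of what that looks like: FSDP2 calls `fsdp_pre_all_gather` / `fsdp_post_all_gather` hooks on tensor subclasses, which is how the float8 path quantizes before communication. The class below is a hypothetical, stripped-down sketch (a plain object rather than a real `torch.Tensor` subclass) that only illustrates the data flow, with the all-gather simulated by a concat:

```python
import torch

class Int8QATShard:
    # Hypothetical sketch of the float8 fsdp_utils pattern. In a real
    # implementation this would be a torch.Tensor subclass so FSDP2 can
    # invoke these hooks on the sharded parameter.
    def __init__(self, fp_shard, scale):
        self.fp_shard = fp_shard
        self.scale = scale

    def fsdp_pre_all_gather(self, mesh=None):
        # Real quantize on the local 1/NGPU shard: only int8 data
        # goes over the wire.
        q = torch.clamp(torch.round(self.fp_shard / self.scale), -128, 127).to(torch.int8)
        return (q,), (self.scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        # Dequantize the gathered int8 into the compute dtype.
        (q,), (scale,) = all_gather_outputs, metadata
        return q.to(param_dtype) * scale, (q,)

# Simulate 2 ranks: quantize each shard, "all-gather" via cat, dequantize.
torch.manual_seed(0)
w = torch.randn(8, 4)
scale = (w.abs().max() / 127).item()
shards = [Int8QATShard(s, scale) for s in w.chunk(2)]
gathered = torch.cat([s.fsdp_pre_all_gather()[0][0] for s in shards])
deq, _ = shards[0].fsdp_post_all_gather((gathered,), (scale,), torch.float32)

# Matches the fake-quantize result on the full weight.
fake = torch.clamp(torch.round(w / scale), -128, 127) * scale
assert torch.equal(deq, fake)
```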
