Open
Description
Hey Team,
I'm trying to use FSDP1/FSDP2 with Float8InferenceLinear, but I'm running into issues (with torch 2.3.1+cu118). Would you suggest bumping to a newer version of torch and trying again, or using the training setup without the inference layer? I also tried using the Float8Linear layer directly, without calling the quantization function that converts it to Float8InferenceLinear, but I hit problems with FSDP1: when the amax is computed, some input x tensors are empty (x.numel() == 0) and some contain NaN.
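To make the symptom concrete, here is a minimal sketch of a check that flags empty or NaN inputs before the amax is computed. It assumes the amax is simply max(abs(x)); `describe_amax_input` and `safe_amax` are hypothetical helper names for illustration, not part of any library API, and returning 0 for an empty tensor is an assumption rather than the library's behavior.

```python
import torch


def describe_amax_input(x: torch.Tensor) -> str:
    """Classify a tensor before an amax reduction (hypothetical helper)."""
    if x.numel() == 0:
        # A full reduction over an empty tensor has no defined maximum.
        return "empty"
    if torch.isnan(x).any():
        # A NaN in the input would make the resulting amax NaN as well.
        return "nan"
    return "ok"


def safe_amax(x: torch.Tensor) -> torch.Tensor:
    """Compute max(abs(x)), returning 0 for empty inputs (an assumption)."""
    if x.numel() == 0:
        return torch.zeros((), dtype=torch.float32, device=x.device)
    return torch.max(torch.abs(x)).to(torch.float32)


if __name__ == "__main__":
    samples = {
        "regular": torch.tensor([1.0, -3.0, 2.0]),
        "empty": torch.empty(0),
        "with_nan": torch.tensor([1.0, float("nan")]),
    }
    for name, t in samples.items():
        print(name, describe_amax_input(t))
```

Logging the result of `describe_amax_input` per rank and per module could help narrow down which layers receive the empty or NaN inputs under FSDP1.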
Best regards,
QQ