Description
Hi developers,
Thanks for such a great project!
I want to integrate torchao FP8 GEMM into our training framework. However, in our framework the linear layers are defined in customized modules (where we implement Tensor Parallel or ZeRO-3 weight parallelism), so it is hard to directly swap them with torchao's Float8Linear.
So, could FP8 GEMM be enabled in a more friendly way, such as via module hooks? Module swapping is not flexible enough for our use case.
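To make the request concrete, here is a minimal sketch of the kind of hook-based enablement I have in mind. The helper `to_fp8_and_back` below is only a placeholder for illustration (a simple fake-quant round trip), not an existing torchao API; the point is that the module class itself (our TP/ZeRO-3 wrapper) stays untouched.

```python
import torch
import torch.nn as nn

def to_fp8_and_back(t: torch.Tensor) -> torch.Tensor:
    # Placeholder: round-trip through float8_e4m3fn to mimic FP8 quantization.
    # A real implementation would instead wrap `t` in an FP8 tensor subclass
    # with scales so the GEMM actually runs in FP8.
    return t.to(torch.float8_e4m3fn).to(t.dtype)

def enable_fp8_via_hooks(linear: nn.Linear) -> None:
    # Quantize the activation right before the module's forward runs,
    # without replacing the module instance.
    def pre_hook(module, args):
        (x, *rest) = args
        return (to_fp8_and_back(x), *rest)

    linear.register_forward_pre_hook(pre_hook)

# Usage: apply to the nn.Linear instances inside our custom parallel modules.
layer = nn.Linear(128, 128, dtype=torch.bfloat16)
enable_fp8_via_hooks(layer)
out = layer(torch.randn(4, 128, dtype=torch.bfloat16))
```

Something along these lines (hooks, or any other opt-in mechanism that does not require replacing the module class) would make it much easier to adopt FP8 GEMM in frameworks with custom linear modules.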