
Modify FLOPs in MFU calculation for causal mask when using FlashAttention #341

@Yuxin-CV


Hi, I suggest modifying the FLOPs calculation used for MFU so that it matches the FlashAttention benchmark script.

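Specifically, with the current calculation the causal-mask case can exceed 100% MFU at seq_len = 16k (189 * 2 / 312 ≈ 1.21), which is clearly inaccurate. When a causal mask is applied, FlashAttention skips the blocks above the diagonal of the attention matrix and therefore performs only about half the FLOPs of the non-causal case. The FLOPs for the causal-mask setting should thus be divided by 2 when using FlashAttention.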
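A minimal sketch of the proposed accounting, assuming the FLOP convention of FlashAttention's benchmark script as I understand it: `4 * batch * seqlen**2 * nheads * headdim` for the forward pass, halved for causal, and multiplied by 3.5 for forward + backward. The names `attention_flops` and `mfu`, the 16-head / head-dim-128 configuration and the runtime are hypothetical, chosen only to reproduce the 189 TFLOPs/s figure above. It shows that the same runtime gives about 61% MFU with the halved count and about 121% without it.

```python
# Sketch: attention FLOPs and MFU with the causal-mask halving applied.
PEAK_TFLOPS = 312.0  # A100 dense BF16/FP16 tensor-core peak (TFLOPs/s), as used in the issue


def attention_flops(batch, seqlen, nheads, headdim, causal, mode="fwd_bwd"):
    """FLOPs for one attention call (assumed FlashAttention benchmark convention)."""
    assert mode in ("fwd", "bwd", "fwd_bwd")
    # Forward: QK^T and PV each cost 2 * seqlen^2 * headdim FLOPs per head.
    f = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        # FlashAttention skips blocks above the diagonal, so roughly half the work is done.
        f //= 2
    if mode == "fwd":
        return f
    # Backward is counted as 2.5x forward (five matmuls vs. two).
    return 2.5 * f if mode == "bwd" else 3.5 * f


def mfu(flops, seconds, peak_tflops=PEAK_TFLOPS):
    """Model FLOPs utilization: achieved TFLOPs/s divided by hardware peak."""
    return flops / seconds / 1e12 / peak_tflops


# Hypothetical config: seq_len = 16k, 16 heads of dim 128, batch 1.
cfg = dict(batch=1, seqlen=16384, nheads=16, headdim=128)
causal_flops = attention_flops(**cfg, causal=True)
full_flops = attention_flops(**cfg, causal=False)

# Hypothetical runtime chosen so the halved count gives 189 TFLOPs/s.
seconds = causal_flops / 189e12

print(f"MFU with causal halving:    {mfu(causal_flops, seconds):.2f}")  # 0.61
print(f"MFU without causal halving: {mfu(full_flops, seconds):.2f}")    # 1.21
```

Counting the full (non-causal) FLOPs for a causal run doubles the numerator while the runtime stays the same, which is exactly how the >100% figure arises.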

[Figure: flash2_a100_fwd_bwd_benchmark]

Labels: duplicate (This issue or pull request already exists)
