[TL] initial implement flashattention op in TL #202

Merged 6 commits on Sep 30, 2024
1 change: 1 addition & 0 deletions bitblas/__init__.py
@@ -39,6 +39,7 @@
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

import warnings
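
For context, a minimal usage sketch of the symbols this PR exports from the package root; the shape and dtype values below are illustrative assumptions (not taken from this PR) and running it requires a CUDA-capable machine:

import bitblas

# Hypothetical shapes chosen only to illustrate the API added by this PR.
config = bitblas.FlashAttenConfig(
    batch=1,
    heads=8,
    kv_heads=8,
    seq_len=1024,
    dim=64,
    Q_dtype="float16",
    K_dtype="float16",
    V_dtype="float16",
    Accu_dtype="float32",
    Out_dtype="float16",
    layout="nnn",
    is_causal=True,
)
# backend="tl" is the only backend the FlashAtten operator accepts in this PR.
flashatten = bitblas.FlashAtten(config, enable_tuning=False, backend="tl")
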
205 changes: 205 additions & 0 deletions bitblas/ops/general_flashatten/__init__.py
@@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas.base.roller.hint import Hint
from tvm.target import Target
from .tilelang import select_scheduler as consistent_scheduler
from ..base_scheduler import BaseScheduler
from ..operator import OperatorConfig, Operator, BaseKernelNameGenerator
from ...base.arch.cuda import CUDA
from ...utils import auto_detect_nvidia_target
from dataclasses import dataclass
from typing import Union, Tuple, Literal, Optional
import logging
import torch

logger = logging.getLogger(__name__)

WORKSPACE_SIZE = 1024 * 1024 * 256


def is_native_compute(Q_dtype, K_dtype, V_dtype) -> bool:
return Q_dtype == K_dtype and K_dtype == V_dtype


@dataclass(frozen=True)
class FlashAttenConfig(OperatorConfig):
batch: Union[int, Tuple[int]] = None
# TODO: distinguish between q_heads and kv_heads
heads: Optional[int] = None
kv_heads: Optional[int] = None
seq_len: Optional[int] = None
dim: Optional[int] = None
Q_dtype: str = "float16"
K_dtype: str = Q_dtype  # same default as Q_dtype ("float16")
V_dtype: str = Q_dtype
Accu_dtype: str = "float32"
Out_dtype: str = "float16"
layout: Literal["nnn", "ntn"] = "nnn"
is_causal: bool = False


class FlashAttenKernelNameGenerator(BaseKernelNameGenerator):

KERNEL_PREFIX = "flashatten"

def is_valid_config(self, config: OperatorConfig) -> bool:
return isinstance(config, FlashAttenConfig)

@staticmethod
def simplify_dtype(dtype: str) -> str:
if dtype.startswith("float"):
return f"f{dtype[5:]}"
elif dtype.startswith("bfloat"):
return f"bf{dtype[6:]}"
elif dtype.startswith("int"):
return f"i{dtype[3:]}"
elif dtype.startswith("uint"):
return f"u{dtype[4:]}"
else:
raise ValueError("Currently only support float, bfloat, int, uint")

def generate(self, hint: Hint = None) -> str:
config = self.config
kernel_name = self.KERNEL_PREFIX
shape_str = f"batch{self.config.batch}heads{self.config.heads}seqlen{self.config.seq_len}dim{self.config.dim}"
Q_dtype = self.simplify_dtype(config.Q_dtype)
K_dtype = self.simplify_dtype(config.K_dtype)
V_dtype = self.simplify_dtype(config.V_dtype)
Accu_dtype = self.simplify_dtype(config.Accu_dtype)
Out_dtype = self.simplify_dtype(config.Out_dtype)
precision_str = f"Q{Q_dtype}_K{K_dtype}_V{V_dtype}_Accu{Accu_dtype}_Out{Out_dtype}"
kernel_name = "_".join([kernel_name, shape_str, precision_str])
# TODO need to add hint
assert self.is_valid(kernel_name), "Kernel name invalid"
return kernel_name


class FlashAtten(Operator):

BITBLAS_TRICK_DTYPE_MAP = {
"float32": ("fp", 32),
"float16": ("fp", 16),
"int8": ("int", 8),
"int4": ("int", 4),
}

def __init__(
self,
config: FlashAttenConfig,
name: str = "flashatten",
target: Optional[Union[str, Target]] = None,
enable_tuning: bool = True,
from_database: bool = False,
backend: str = "tl",
):
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")

assert (config.Q_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.Q_dtype}"
assert (config.K_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.K_dtype}"
assert (config.V_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.V_dtype}"
assert backend == "tl", "FlashAttention only support TL compiler"

source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.Q_dtype]

self.source_format = source_format
self.bit = bit
self.backend = backend
super().__init__(name, config, target, backend)

target = self.target
if target.kind.name != "cuda":
raise ValueError("Currently only support cuda target")

self.dispatch_tl(target, from_database, source_format, enable_tuning)

def dispatch_tl(self,
target: Target,
from_database: bool = False,
source_format: str = "fp16",
enable_tuning: bool = True):
self.arch = CUDA(target)
if not from_database:
self._build_default_module(target)
self.workspace = None
if enable_tuning:
self.hardware_aware_finetune()
self.torch_output_dtype = getattr(torch, self.Out_dtype)

def get_kernel_name_generator(self):
return FlashAttenKernelNameGenerator(self.config)

def _alloc_workspace(self):
return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda()

def _free_workspace(self):
# release the workspace if it was allocated
if self.workspace is not None:
self.workspace = None

def _select_scheduler(self) -> Optional[BaseScheduler]:
if is_native_compute(self.Q_dtype, self.K_dtype, self.V_dtype):
return consistent_scheduler(
batch=self.batch,
heads=self.heads,
seq_len=self.seq_len,
dim=self.dim,
layout=self.layout,
dtype_QKV=self.Q_dtype,
dtype_Out=self.Out_dtype,
dtype_Accu=self.Accu_dtype,
is_causal=self.is_causal,
)
else:
raise ValueError("Currently only support native compute for scheduler")

def cleanup(self):
self._free_workspace()

@property
def batch(self):
return self.config.batch

@property
def heads(self):
return self.config.heads

@property
def seq_len(self):
return self.config.seq_len

@property
def dim(self):
return self.config.dim

@property
def Q_dtype(self):
return self.config.Q_dtype

@property
def K_dtype(self):
return self.config.K_dtype

@property
def V_dtype(self):
return self.config.V_dtype

@property
def Accu_dtype(self):
return self.config.Accu_dtype

@property
def Out_dtype(self):
return self.config.Out_dtype

@property
def layout(self):
return self.config.layout

@property
def is_causal(self):
return self.config.is_causal
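
As a reading aid, a minimal sketch of the kernel name FlashAttenKernelNameGenerator produces for the hypothetical config from the earlier example (the shape values are assumptions, not output captured from this PR):

name_gen = FlashAttenKernelNameGenerator(config)
print(name_gen.generate())
# roughly: flashatten_batch1heads8seqlen1024dim64_Qf16_Kf16_Vf16_Accuf32_Outf16
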
38 changes: 38 additions & 0 deletions bitblas/ops/general_flashatten/tilelang/__init__.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .flashatten import flashatten_blocked # noqa: F401
from .flashatten import FlashAttenScheduler # noqa: F401


def parse_layout(layout: str):
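# The layout string describes (Q, K, V); only K may be transposed, so the
# supported values are "nnn" and "ntn" (see FlashAttenConfig.layout).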
trans_Q = False
trans_K = layout[1] == 't'
trans_V = False
return trans_Q, trans_K, trans_V


def select_scheduler(
batch=None,
heads=None,
seq_len=None,
dim=None,
layout="nnn",
dtype_QKV="float16",
dtype_Out="float16",
dtype_Accu="float32",
is_causal=False,
):
trans_list = parse_layout(layout)
trans_K = trans_list[1]
return FlashAttenScheduler(
batch=batch,
heads=heads,
seq_len=seq_len,
dim=dim,
trans_K=trans_K,
dtype_QKV=dtype_QKV,
dtype_Out=dtype_Out,
dtype_Accu=dtype_Accu,
is_causal=is_causal,
)
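
A small sketch of how these helpers fit together, assuming the hypothetical shapes used earlier (illustrative values, not captured output from this PR):

# "ntn" marks K as transposed; Q and V are always non-transposed.
trans_Q, trans_K, trans_V = parse_layout("ntn")
assert (trans_Q, trans_K, trans_V) == (False, True, False)

scheduler = select_scheduler(
    batch=1,
    heads=8,
    seq_len=1024,
    dim=64,
    layout="ntn",
    dtype_QKV="float16",
    dtype_Out="float16",
    dtype_Accu="float32",
    is_causal=False,
)  # returns a FlashAttenScheduler configured with trans_K=True
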