Skip to content

Optimize flash bert path for hpu device #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 79 additions & 30 deletions backends/python/server/text_embeddings_server/models/flash_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from typing import Type, List
from typing import Type, List, Union
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.models.bert import BertConfig
from opentelemetry import trace
from text_embeddings_server.models import Model
from text_embeddings_server.models.types import FlashBatch, Embedding
from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch
from text_embeddings_server.utils.flash_attn import attention
from text_embeddings_server.utils.device import use_ipex

Expand Down Expand Up @@ -166,22 +166,41 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.num_heads = config.num_attention_heads
self.device = device

def forward(self, hidden_states, cu_seqlens, max_s):
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
residual = hidden_states

qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=1
)

qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)
bs = 1
hidden_dim = hidden_states.size(-1)
is_flat = True
if hidden_states.dim() > 2:
is_flat = False
bs = hidden_states.size(0)
q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=2
)
else:
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
self.num_heads, dim=1
)
attn_output = torch.empty_like(q)
attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale)
attention(
q,
k,
v,
attn_output,
cu_seqlens,
max_s,
self.softmax_scale,
attn_mask=attn_mask,
)

hidden_states = torch.addmm(
self.dense_bias,
attn_output.view(-1, self.num_heads * self.head_size),
self.dense_weight,
)
if not is_flat:
hidden_states = hidden_states.view(bs, -1, hidden_dim)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

return hidden_states
Expand Down Expand Up @@ -224,19 +243,16 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
f"{prefix}.output.LayerNorm", handle, device, dtype, config
)

def forward(self, hidden_states, cu_seqlens, max_s):
hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
hidden_states = self.attention.forward(
hidden_states, cu_seqlens, max_s, attn_mask
)
residual = hidden_states

hidden_states = torch.addmm(
self.intermediate_bias, hidden_states, self.intermediate_weight
hidden_states = F.linear(
hidden_states, self.intermediate_weight.T, self.intermediate_bias
)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = torch.addmm(
self.output_bias,
hidden_states,
self.output_weight,
)
hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
return hidden_states

Expand All @@ -248,9 +264,9 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
for i in range(config.num_hidden_layers)
]

def forward(self, hidden_states, cu_seqlens, max_s):
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
for layer in self.layers:
hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, attn_mask)
return hidden_states


Expand All @@ -259,10 +275,21 @@ def __init__(self, handle, device, dtype, config: BertConfig):
self.embeddings = BertEmbeddings("embeddings", handle, device, dtype, config)
self.encoder = BertEncoder("encoder", handle, device, dtype, config)

def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
def forward(
self,
input_ids,
token_type_ids,
position_ids,
cu_seqlens,
max_s,
mask=None,
attn_mask=None,
):
embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)

encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, attn_mask)
if mask is not None:
outputs = encoder_outputs[mask]
return outputs[cu_seqlens[:-1]]
return encoder_outputs[cu_seqlens[:-1]]


Expand All @@ -271,6 +298,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
config = BertConfig.from_pretrained(model_path)
with safe_open(model_path / "model.safetensors", framework="pt") as f:
model = FlashBertModel(f, device, dtype, config)
self.device = device
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

Expand All @@ -280,17 +308,38 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)

@property
def batch_type(self) -> Type[FlashBatch]:
return FlashBatch
def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
# for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet
return FlashBatch if self.device.type != "hpu" else PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: FlashBatch) -> List[Embedding]:
def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
if isinstance(batch, PaddedBatch):
input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
max_input_lens = input_lens.max().item()
cu_seqlens = torch.cat(
(input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
)
mask = batch.attention_mask.to(torch.bool)
batch_size = input_lens.size(0)
attn_mask = torch.empty(
[batch_size, 1, 1, mask.shape[-1]], device=self.device
).fill_(float("-inf"))
attn_mask[:, :, :, :].masked_fill_(mask[:, None, None, :], 0)
elif isinstance(batch, FlashBatch):
cu_seqlens = batch.cu_seqlens
mask = None
attn_mask = None
max_input_lens = batch.max_s

embedding = self.model.forward(
input_ids=batch.input_ids,
token_type_ids=batch.token_type_ids,
position_ids=batch.position_ids,
cu_seqlens=batch.cu_seqlens,
max_s=batch.max_s,
cu_seqlens=cu_seqlens,
max_s=max_input_lens,
mask=mask,
attn_mask=attn_mask,
)
cpu_results = embedding.view(-1).tolist()

Expand Down
10 changes: 10 additions & 0 deletions backends/python/server/text_embeddings_server/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
import torch
import subprocess

ALLOW_REDUCED_PRECISION = os.getenv(
"ALLOW_REDUCED_PRECISION_FP16_BF16", "true"
).lower() in [
"true",
"1",
]


def _is_ipex_available():
def get_major_and_minor_from_version(full_version):
Expand Down Expand Up @@ -55,6 +62,9 @@ def get_device():
elif is_hpu():
import habana_frameworks.torch.core as htcore

# WA for perf degradation from pytorch 2.5
if ALLOW_REDUCED_PRECISION:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore
device = torch.device("hpu")
elif use_ipex():
Expand Down
65 changes: 11 additions & 54 deletions backends/python/server/text_embeddings_server/utils/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def hpu_attn(
k,
v,
out,
attn_mask,
seqlen_q,
seqlen_k,
max_seqlen_q,
Expand All @@ -71,66 +72,21 @@ def hpu_attn(
):
from habana_frameworks.torch.hpex.kernels import FusedSDPA

total_q, num_head, head_size = q.size()
total_k, num_head_k, _ = k.size()
batch_size = seqlen_q.size(0) - 1
seqlen_q_ = seqlen_q.clone()
seqlen_q_[:batch_size] = seqlen_q[1:]
seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
seqlen_k_ = seqlen_k.clone()
seqlen_k_[:batch_size] = seqlen_k[1:]
seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]

pad_q = torch.zeros(
[batch_size, max_seqlen_q, num_head, head_size],
dtype=q.dtype,
device=q.device,
)
pad_k = torch.zeros(
[batch_size, max_seqlen_k, num_head_k, head_size],
dtype=k.dtype,
device=k.device,
)
pad_v = torch.zeros(
[batch_size, max_seqlen_k, num_head_k, head_size],
dtype=v.dtype,
device=v.device,
)
q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
batch_size, 1
)
q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
batch_size, 1
)
k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
align_mask_seqlen = max_seqlen_k
attn_mask = torch.empty(
[batch_size, 1, 1, align_mask_seqlen],
dtype=q.dtype,
device=q.device,
).fill_(float("-inf"))
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)

pad_q[q_mask] = q
pad_k[k_mask] = k
pad_v[k_mask] = v

pad_q = pad_q.permute(0, 2, 1, 3)
pad_k = pad_k.permute(0, 2, 1, 3)
pad_v = pad_v.permute(0, 2, 1, 3)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if is_causal:
attn_mask = None

out_ = FusedSDPA.apply(
pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale
)
out_ = out_.permute(0, 2, 1, 3)
out.copy_(out_[q_mask])
out_ = FusedSDPA.apply(q, k, v, attn_mask, 0.0, is_causal, softmax_scale)
out_ = out_.transpose(1, 2)
out.copy_(out_)
return out


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
def attention(
q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False, attn_mask=None
):
if HAS_FLASH_ATTN_V2:
if use_ipex:
import intel_extension_for_pytorch as ipex
Expand All @@ -157,6 +113,7 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
k,
v,
out,
attn_mask,
cu_seqlens,
cu_seqlens,
max_s,
Expand Down
Loading