2 changes: 1 addition & 1 deletion Dockerfile-intel
@@ -77,7 +77,7 @@ RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url
RUN cd backends/python/server && \
make install

FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest AS hpu
FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

11 changes: 5 additions & 6 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -37,6 +37,7 @@ def get_model(model_path: Path, dtype: Optional[str]) :
raise RuntimeError(f"Unknown dtype {dtype}")

device = get_device()
logger.info(f"backend device: {device}")
config = AutoConfig.from_pretrained(model_path)
if config.model_type == "bert":
config: BertConfig
@@ -48,14 +49,12 @@ def get_model(model_path: Path, dtype: Optional[str]) :
):
return FlashBert(model_path, device, datatype) # type: ignore
if use_ipex() and device.type in ["cpu", "xpu"]:
import intel_extension_for_pytorch as ipex
return FlashBert(model_path, device, datatype) # type: ignore
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
model_handle = DefaultModel(model_path, device, datatype)
model_handle.model = wrap_in_hpu_graph(model_handle.model, disable_tensor_cache=True)
return model_handle
import habana_frameworks.torch.core as htcore
return FlashBert(model_path, device, datatype)

return DefaultModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype)
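
As a usage sketch (not part of the diff; the model path and dtype string below are purely illustrative): after this change, get_model resolves to FlashBert on a Gaudi host, with habana_frameworks.torch.core imported only for its side effects, instead of wrapping a DefaultModel in an HPU graph.

from pathlib import Path
from text_embeddings_server.models import get_model

# Hypothetical call; the path and dtype are placeholders, not values from the PR.
model = get_model(Path("/data/my-embedding-model"), "float32")
# On an HPU device this now returns a FlashBert instance; on plain CPU without
# IPEX it still falls back to DefaultModel.
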
39 changes: 33 additions & 6 deletions backends/python/server/text_embeddings_server/models/flash_bert.py
@@ -1,27 +1,43 @@
import torch

from pathlib import Path
from torch import nn
import torch.nn.functional as F
from typing import Type, List
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.models.bert import BertConfig
from opentelemetry import trace


from text_embeddings_server.models import Model
from text_embeddings_server.models.types import FlashBatch, Embedding
from text_embeddings_server.utils.flash_attn import attention
from text_embeddings_server.utils.device import use_ipex

tracer = trace.get_tracer(__name__)

def hpu_add_layer_norm(
add: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
add_back: bool
):
if add is not None:
added_tensor = torch.add(add, x, alpha=1.0)
output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
if add_back:
add.add_(x)
return output
else:
return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)

class FastLayerNorm:
def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
self.variance_epsilon = config.layer_norm_eps
self.device = device
self.use_ipex = use_ipex()

def forward(self, hidden_states, residual=None):
# Flash attention imports
@@ -48,7 +64,7 @@ def forward(self, hidden_states, residual=None):
)
if res is None:
res = hidden_states
elif use_ipex():
elif self.use_ipex:
import intel_extension_for_pytorch as ipex
normed_hidden_states = ipex.llm.functional.add_layer_norm(
residual,
@@ -60,7 +76,16 @@
)

res = residual if residual is not None else hidden_states

elif self.device.type == "hpu":
normed_hidden_states = hpu_add_layer_norm(
residual,
hidden_states,
self.weight,
self.bias,
self.variance_epsilon,
residual is not None
)
res = residual if residual is not None else hidden_states
return normed_hidden_states, res
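
For context, a small CPU-side check (not from the PR) of what hpu_add_layer_norm computes: a layer norm over residual + hidden_states, where add_back=True additionally folds hidden_states into the residual tensor in place, matching the add_(x) call above. It assumes hpu_add_layer_norm is in scope as defined in this hunk.

import torch
import torch.nn.functional as F

hidden = torch.randn(4, 8)
residual = torch.randn(4, 8)
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-12

res0 = residual.clone()
expected = F.layer_norm(res0 + hidden, [8], weight, bias, eps)
out = hpu_add_layer_norm(residual, hidden, weight, bias, eps, True)

assert torch.allclose(out, expected)
assert torch.allclose(residual, res0 + hidden)  # add_back=True updated the residual in place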


@@ -242,7 +267,9 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
config = BertConfig.from_pretrained(model_path)
with safe_open(model_path / "model.safetensors", framework="pt") as f:
model = FlashBertModel(f, device, dtype, config)

if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
self.hidden_size = config.hidden_size

super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
6 changes: 3 additions & 3 deletions backends/python/server/text_embeddings_server/utils/device.py
@@ -27,11 +27,11 @@ def get_major_and_minor_from_version(full_version):
return False
return True

def _is_hpu() -> bool:
def is_hpu() -> bool:
is_hpu_available = True
try:
subprocess.run(["hl-smi"], capture_output=True, check=True)
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
except:
is_hpu_available = False
return is_hpu_available

@@ -43,7 +43,7 @@ def get_device() :
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
elif _is_hpu():
elif is_hpu():
import habana_frameworks.torch.core as htcore
if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore
device = torch.device("hpu")
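
A hypothetical usage sketch (not from the PR): with the probe now public as is_hpu, callers such as flash_attn.py below can reuse the same hl-smi based detection that get_device relies on.

from text_embeddings_server.utils.device import get_device, is_hpu

device = get_device()            # prefers CUDA, then HPU; the remaining branches are collapsed above
print(device.type, is_hpu())     # e.g. "hpu True" on a Gaudi host where hl-smi succeeds
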
78 changes: 72 additions & 6 deletions backends/python/server/text_embeddings_server/utils/flash_attn.py
@@ -1,6 +1,6 @@
import os
import torch
from text_embeddings_server.utils.device import use_ipex
from text_embeddings_server.utils.device import use_ipex, is_hpu

from loguru import logger

@@ -10,7 +10,10 @@
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

if use_ipex():
is_hpu = is_hpu()
use_ipex = use_ipex()

if use_ipex or is_hpu:
HAS_FLASH_ATTN_V2 = True
else:
if not torch.cuda.is_available():
@@ -54,14 +57,77 @@
HAS_FLASH_ATTN = True


def hpu_attn(q, k, v, out, seqlen_q, seqlen_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal=False):
from habana_frameworks.torch.hpex.kernels import FusedSDPA
total_q, num_head, head_size = q.size()
total_k, num_head_k, _ = k.size()
batch_size = seqlen_q.size(0) - 1
seqlen_q_ = seqlen_q.clone()
seqlen_q_[:batch_size] = seqlen_q[1:]
seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
seqlen_k_ = seqlen_k.clone()
seqlen_k_[:batch_size] = seqlen_k[1:]
seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]

pad_q = torch.zeros(
[batch_size, max_seqlen_q, num_head, head_size],
dtype=q.dtype,
device=q.device,
)
pad_k = torch.zeros(
[batch_size, max_seqlen_k, num_head_k, head_size],
dtype=k.dtype,
device=k.device,
)
pad_v = torch.zeros(
[batch_size, max_seqlen_k, num_head_k, head_size],
dtype=v.dtype,
device=v.device,
)
q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
batch_size, 1
)
q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
batch_size, 1
)
k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
align_mask_seqlen = max_seqlen_k
attn_mask = torch.empty(
[batch_size, 1, 1, align_mask_seqlen],
dtype=q.dtype,
device=q.device,
).fill_(float("-inf"))
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)

pad_q[q_mask] = q
pad_k[k_mask] = k
pad_v[k_mask] = v

pad_q = pad_q.permute(0, 2, 1, 3)
pad_k = pad_k.permute(0, 2, 1, 3)
pad_v = pad_v.permute(0, 2, 1, 3)
if is_causal:
attn_mask = None

out_ = FusedSDPA.apply(pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale)
out_ = out_.permute(0, 2, 1, 3)
out.copy_(out_[q_mask])
return out
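
A small numeric sketch (not part of the PR) of the cu_seqlens bookkeeping at the top of hpu_attn: the cumulative offsets of the packed varlen layout are turned into per-sequence lengths, which then drive the padding masks built before FusedSDPA.

import torch

cu_seqlens = torch.tensor([0, 3, 7, 9])   # three packed sequences of lengths 3, 4, 2
batch_size = cu_seqlens.size(0) - 1
shifted = cu_seqlens.clone()
shifted[:batch_size] = cu_seqlens[1:]
lengths = (shifted - cu_seqlens)[:batch_size]
print(lengths)                             # tensor([3, 4, 2])

max_seqlen = 4
mask = torch.arange(max_seqlen)[None, :] < lengths[:, None]
print(mask.sum().item())                   # 9 == total_q, every packed token gets a padded slot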


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
if HAS_FLASH_ATTN_V2:
if use_ipex():
if use_ipex:
import intel_extension_for_pytorch as ipex
return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
max_s, max_s, 0, softmax_scale,
zero_tensors=False, is_causal=False,
return_softmax=False, gen_=None)
elif is_hpu:
return hpu_attn(q, k, v, out, cu_seqlens, cu_seqlens,
max_s, max_s, softmax_scale, is_causal=False)

else:
return flash_attn_2_cuda.varlen_fwd(
q,