Formatting, Refactor, and fix Python backend import loop #7
Conversation
@pi314ever, I'd prefer not to add reformatting changes to this PR, since this PR is mainly about adding the multi-backend design, not reformatting the TEI coding style; mixing the two in one PR does not comply with the "one PR, one purpose" rule. If you find a bug introduced by the multi-backend design that needs fixing, please append a bug-fix PR to this one. You can submit a separate reformatting PR to the TEI repo once this multi-backend PR is merged.
```diff
@@ -29,8 +29,6 @@ tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
-rand = "0.8"
```
Let's keep it.
```diff
@@ -15,7 +15,6 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
 text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
-rand = { workspace = true }
```
Keeping it is OK.
```diff
@@ -40,4 +40,4 @@ typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
```
LGTM
```python
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashBert)
```
Let's keep it.
```python
from text_embeddings_server.utils.device import get_device, use_ipex
```
Let's keep it
```python
from optimum.habana.transformers.modeling_utils import (
    adapt_transformers_to_gaudi,
)

adapt_transformers_to_gaudi()
```
LGTM
```python
from text_embeddings_server.models.types import PaddedBatch, Embedding

tracer = trace.get_tracer(__name__)
```
LGTM
```diff
-            residual is not None
-        )
+            residual is not None,
+        )
```
LGTM
```diff
@@ -269,6 +273,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         model = FlashBertModel(f, device, dtype, config)
         if device.type == "hpu":
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
             model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size
```
LGTM
```diff
 from text_embeddings_server.models.types import Batch, Embedding

 B = TypeVar("B", bound=Batch)


-class Model(ABC):
+class Model(ABC, Generic[B]):
```
Let's keep it. Do not use this.
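For context, here is a minimal sketch of what the proposed (and declined) `Generic[B]` typing would enable. The stub classes and the `embed()` signature are assumptions for the example, not the repository's exact API:

```python
# Sketch only: illustrates the rejected Generic[B] change.
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar


class Batch:
    pass


class PaddedBatch(Batch):
    pass


class Embedding:
    pass


B = TypeVar("B", bound=Batch)


class Model(ABC, Generic[B]):
    @abstractmethod
    def embed(self, batch: B) -> List[Embedding]:
        ...


class DefaultModel(Model[PaddedBatch]):
    # With Generic[B], a type checker knows embed() takes a PaddedBatch here;
    # with the plain Model(ABC) base, batch would only be typed as Batch.
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        return []
```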
```diff
@@ -8,7 +8,7 @@
 from pathlib import Path
 from typing import Optional

-from text_embeddings_server.models import Model, get_model
+from text_embeddings_server.models import get_model, Model
```
LGTM
```diff
@@ -27,6 +32,7 @@ def get_major_and_minor_from_version(full_version):
         return False
     return True
```
Above is OK.
```diff
 def use_ipex() -> bool:
     value = os.environ.get("USE_IPEX", "True").lower()
-    return (value in ["true", "1"] and _is_ipex_available())
+    return value in ["true", "1"] and _is_ipex_available()
```
Adding `()` is better; do not change it.
device = torch.device("hpu") | ||
elif use_ipex(): | ||
import intel_extension_for_pytorch as ipex | ||
|
||
if hasattr(torch, "xpu") and torch.xpu.is_available(): | ||
device = torch.device("xpu") |
LGTM
```python
    max_s,
    softmax_scale,
    is_causal=False,
)
```
LGTM
```rust
    .output()
{
    Ok(output) => output.status.success(),
    Err(_) => false,
```
Do not change it
@yao-matrix I agree with "one PR, one purpose", but I believe the code should maintain the formatting guidelines already followed by the upstream repository, and all of the changes in this PR deal with code we will add in support of multi-backend. The change to exclude the rand package resolves an inconsistency in how `max_batch_token` is used during warmup. The bug fix was simply an import loop that occurs in the default model, which is a single-line fix.
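For illustration, a hypothetical sketch of that kind of import loop and its one-line fix; the module layout (`models/model.py` defining `Model`, `models/__init__.py` importing `DefaultModel`) is inferred from the diff, not confirmed:

```python
# Hypothetical sketch of default_model.py. The package __init__
# (text_embeddings_server/models/__init__.py) imports DefaultModel from this
# module, so importing Model back through the package re-enters the partially
# initialized __init__ and fails at startup.

# Before (circular):
# from text_embeddings_server.models import Model

# After (one-line fix): import Model from the module that defines it.
from text_embeddings_server.models.model import Model
from text_embeddings_server.models.types import Embedding, PaddedBatch


class DefaultModel(Model):
    def embed(self, batch: PaddedBatch) -> list[Embedding]:
        raise NotImplementedError
```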
```diff
@@ -99,20 +99,19 @@ impl Backend {
     #[instrument(skip(self))]
     pub async fn warmup_hpu(
         &self,
         mut max_input_length: usize,
         max_token: usize,
```
`max_input_length` may be changed in this function.
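A tiny Rust sketch of the point being made (the function name is illustrative): a by-value parameter can only be reassigned inside the function if it is declared `mut`.

```rust
// Sketch only: why `mut max_input_length` matters. Without `mut`, the
// reassignment below would not compile.
fn clamp_warmup_length(mut max_input_length: usize, max_warmup_length: usize) -> usize {
    // The warmup path may shrink the value before using it.
    max_input_length = std::cmp::min(max_input_length, max_warmup_length);
    max_input_length
}

fn main() {
    assert_eq!(clamp_warmup_length(4096, 1024), 1024);
}
```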
```rust
max_batch_token
)
)
);
```
LGTM. Adding this check is better.
```rust
let max_input_length = std::cmp::min(max_input_length, max_warmup_length);
let mut seq_lengths: Vec<usize> = (seq_bucket_size..max_input_length + 1)
    .step_by(seq_bucket_size)
    .collect();
```
LGTM
```diff
@@ -146,7 +160,7 @@ impl Backend {
             }
         }
         for shape in shapes.iter() {
-            let batch = self.create_warmup_batch(*shape, max_token as u32);
+            let batch = self.create_warmup_batch(*shape);
```
Let's keep it
```diff
-        let input_ids: Vec<u32> = (0..length).map(|_| rand::thread_rng().gen_range(0..max_token)).collect();
-        let token_type_ids: Vec<u32> = vec![0; length as usize];
+        let input_ids = vec![0; length as usize];
+        let token_type_ids = vec![0; length as usize];
```
Let's keep it.
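For reference, a minimal sketch (not the backend's actual code) contrasting the two warmup-input strategies under discussion; `max_token` is assumed to stand for the tokenizer's vocabulary size:

```rust
// Sketch only: random ids exercise varied embedding rows during warmup but
// require the rand crate; zero-filled ids are deterministic and dependency-free.
use rand::Rng;

fn warmup_input_ids(length: usize, max_token: u32, randomize: bool) -> Vec<u32> {
    if randomize {
        let mut rng = rand::thread_rng();
        (0..length).map(|_| rng.gen_range(0..max_token)).collect()
    } else {
        vec![0; length]
    }
}
```

Either approach produces the same tensor shapes; the difference is only in the token values fed to the model and the extra rand dependency.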
```diff
-        return self.warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests).await;
+        return self
+            .warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests)
+            .await;
     }
```
LGTM
@pi314ever, hi, I've added some comments. Some of the code-style reformatting is good and does look better. However, I have a few suggestions:
Closing this PR in favor of #8 as the bug is a non-issue. |
What does this PR do?

Minor changes:

- `cargo fmt` and `ruff format`
- Refactor `warmup_hpu` to not include the rand dependency and to use `max_batch_token` properly

Bug fix:

- `Model` was incorrectly imported in `default_model.py`.