
Conversation

pi314ever

What does this PR do?

Minor changes:

  • Proper formatting with cargo fmt and ruff format
  • Refactor warmup_hpu to drop the rand dependency and use max_batch_token properly

Bug fix:

  • Resolved an import loop in Python where Model was incorrectly imported in default_model.py.


yao-matrix commented Aug 30, 2024

@pi314ever, I'd prefer not to add reformatting changes to this PR, since this PR is mainly about adding the multi-backend design, not reformatting the TEI coding style; mixing things in one PR does not comply with the "one PR, one purpose" rule. If you find a bug that needs fixing and was introduced by the multi-backend design, please append a bug-fix PR to this one.

You can submit a separate reformatting PR to the TEI repo once this multi-backend PR is merged.

@@ -29,8 +29,6 @@ tracing = "0.1"
serde = { version = "1.0", features = ["serde_derive"] }
serde_json = "1.0"
thiserror = "1.0"
rand = "0.8"


Let's keep it.

@@ -15,7 +15,6 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
text-embeddings-backend-ort = { path = "ort", optional = true }
tokio = { workspace = true }
tracing = { workspace = true }
rand = { workspace = true }


Keeping it is OK.

@@ -40,4 +40,4 @@ typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"


LGTM

except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False

if FLASH_ATTENTION:
__all__.append(FlashBert)



Let's keep it.

from text_embeddings_server.utils.device import get_device, use_ipex


Let's keep it

from optimum.habana.transformers.modeling_utils import (
adapt_transformers_to_gaudi,
)

adapt_transformers_to_gaudi()


LGTM

from text_embeddings_server.models.types import PaddedBatch, Embedding

tracer = trace.get_tracer(__name__)


LGTM

residual is not None
)
residual is not None,
)


LGTM

@@ -269,6 +273,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
model = FlashBertModel(f, device, dtype, config)
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
self.hidden_size = config.hidden_size


LGTM


from text_embeddings_server.models.types import Batch, Embedding

B = TypeVar("B", bound=Batch)


class Model(ABC):
class Model(ABC, Generic[B]):


Let's keep it. Do not use this.

@@ -8,7 +8,7 @@
from pathlib import Path
from typing import Optional

from text_embeddings_server.models import Model, get_model
from text_embeddings_server.models import get_model, Model


LGTM

@@ -27,6 +32,7 @@ def get_major_and_minor_from_version(full_version):
return False
return True


Above is OK.

def use_ipex() -> bool:
value = os.environ.get("USE_IPEX", "True").lower()
return (value in ["true", "1"] and _is_ipex_available())
return value in ["true", "1"] and _is_ipex_available()


Adding the () is better; do not change it.

device = torch.device("hpu")
elif use_ipex():
import intel_extension_for_pytorch as ipex

if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")


LGTM

max_s,
softmax_scale,
is_causal=False,
)


LGTM

.output()
{
Ok(output) => output.status.success(),
Err(_) => false,


Do not change it

@pi314ever
Author

@yao-matrix I agree with "one PR, one purpose", but I believe the code should follow the formatting guidelines already used by the upstream repository. All of the changes in this PR deal with code we are adding in support of multi-backend.

The change to exclude the rand package resolves an inconsistency in how max_batch_tokens is interpreted. In the upstream warmup function it defines the maximum number of tokens allowed in each batch, but we were using it as the maximum token ID. Removing the rand package makes our warmup consistent with the upstream version.

The bug fix addresses an import loop in the default model and is a single-line change.
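
For context, here is a minimal sketch of the two interpretations being discussed; the helper names and the zero-fill are illustrative assumptions, not the actual TEI code:

// Intended semantics: max_batch_tokens caps the total token budget of one batch.
fn fits_in_batch(batch_size: usize, seq_len: usize, max_batch_tokens: usize) -> bool {
    batch_size * seq_len <= max_batch_tokens
}

// Previous (inconsistent) use: treating the value as an upper bound for random
// token IDs, roughly `rand::thread_rng().gen_range(0..max_batch_tokens)` per position.
// After the refactor: deterministic zero-filled dummy inputs, with no rand dependency.
fn warmup_input_ids(length: usize) -> Vec<u32> {
    vec![0; length]
}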

@@ -99,20 +99,19 @@ impl Backend {
#[instrument(skip(self))]
pub async fn warmup_hpu(
&self,
mut max_input_length: usize,
max_token: usize,


max_input_length may be changed in this function.

max_batch_token
)
)
);


LGTM. Adding this check is better.

let max_input_length = std::cmp::min(max_input_length, max_warmup_length);
let mut seq_lengths: Vec<usize> = (seq_bucket_size..max_input_length + 1)
.step_by(seq_bucket_size)
.collect();


LGTM

@@ -146,7 +160,7 @@ impl Backend {
}
}
for shape in shapes.iter() {
let batch = self.create_warmup_batch(*shape, max_token as u32);
let batch = self.create_warmup_batch(*shape);


Let's keep it

let input_ids: Vec<u32> = (0..length).map(|_| rand::thread_rng().gen_range(0..max_token)).collect();
let token_type_ids: Vec<u32> = vec![0; length as usize];
let input_ids = vec![0; length as usize];
let token_type_ids = vec![0; length as usize];


Let's keep it.

return self.warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests).await;
return self
.warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests)
.await;
}


LGTM

@kaixuanliu

@pi314ever, hi, I added some comments. Some of the code-style reformatting is good and looks better. However, I have a few suggestions:

  1. It is not recommended to rely on heavy extension tools to check code style; some changes like this one are not necessary: L11
  2. In this PR, it is better not to change original code unless necessary; we can submit a new PR to fix it. Example: R22
  3. As for the warmup logic, I want to use random input here instead of all zeros, and I need to set a reasonable (not too large) upper bound for the rand::thread_rng().gen_range API. It is not a big problem; we can get feedback from reviewers (see the sketch after this list).
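
A rough sketch of the randomized warmup input described in point 3; the bound constant is a placeholder assumption, not a value from the TEI code base:

use rand::Rng;

// Placeholder upper bound for dummy token IDs during warmup; anything safely
// below the model's vocabulary size would work, and the exact value is open
// to reviewer feedback.
const WARMUP_TOKEN_ID_BOUND: u32 = 1_000;

// Random warmup token IDs, as suggested above, instead of all zeros.
fn random_warmup_input_ids(length: usize) -> Vec<u32> {
    let mut rng = rand::thread_rng();
    (0..length)
        .map(|_| rng.gen_range(0..WARMUP_TOKEN_ID_BOUND))
        .collect()
}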

@pi314ever
Author

Closing this PR in favor of #8 as the bug is a non-issue.

pi314ever closed this Aug 30, 2024