diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py
index 4f2cfa47..f27572a9 100644
--- a/backends/python/server/text_embeddings_server/models/types.py
+++ b/backends/python/server/text_embeddings_server/models/types.py
@@ -11,10 +11,12 @@
 tracer = trace.get_tracer(__name__)
 
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
 
 
-def round_up(number, k):
-    return (number + k - 1) // k * k
+def round_up_seq(number, k, base):
+    exponent = max(0, math.ceil(math.log(number / k, base)))
+    return int(k * (base**exponent))
 
 
 class Batch(ABC):
@@ -46,7 +48,9 @@ def from_pb(
         batch_size = len(pb.cu_seq_lengths) - 1
         if device.type == "hpu":
             # To better utilize HPU, we need to do batch/seq_len bucketing
-            max_length = round_up(pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            max_length = round_up_seq(
+                pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+            )
             max_length = min(max_length, max_input_length)
             new_bs = 2 ** math.ceil(math.log2(batch_size))
         else:
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index b86e35df..d333951c 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -39,6 +39,21 @@ fn powers_of_two(max_value: usize) -> Vec<usize> {
     result
 }
 
+fn generate_bucket_sizes(bucket_size: usize, max_s: usize, base_exp: usize) -> Vec<usize> {
+    let mut sizes = Vec::new();
+    let mut current = bucket_size;
+
+    while current <= max_s {
+        sizes.push(current);
+        match current.checked_mul(base_exp) {
+            Some(next) => current = next,
+            None => break,
+        }
+    }
+
+    sizes
+}
+
 fn is_hpu() -> bool {
     match Command::new("hl-smi")
         .args(["-Q", "name", "-f", "csv"])
@@ -114,7 +129,7 @@ impl Backend {
         };
         let seq_bucket_size: usize = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
         let max_warmup_length: usize = read_env_var("MAX_WARMUP_SEQUENCE_LENGTH", 1024);
-
+        let seq_len_exp_base: usize = read_env_var("SEQ_LEN_EXPONENT_BASE", 2);
         let max_batch_size = max_bs.unwrap_or_else(|| read_env_var("MAX_WARMUP_BATCH_SIZE", 8));
 
         let mut batch_sizes: Vec<usize> = powers_of_two(max_batch_size);
@@ -135,9 +150,11 @@
         }
         max_input_length = std::cmp::min(max_input_length, max_warmup_length);
-        let mut seq_lengths: Vec<usize> = (seq_bucket_size..=max_input_length)
-            .step_by(seq_bucket_size)
-            .collect();
+        let mut seq_lengths: Vec<usize> = generate_bucket_sizes(
+            seq_bucket_size,
+            max_input_length,
+            seq_len_exp_base,
+        );
 
         if let Some(&last) = seq_lengths.last() {
             if last < max_input_length {
                 seq_lengths.push(max_input_length);
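
For intuition, the sketch below compares the old multiple-of-`k` padding with the new exponential rule from `round_up_seq` above. It is a standalone illustration, not part of the diff; the sample lengths are arbitrary, and `k = 128` / `base = 2` are just the defaults of `PAD_SEQUENCE_TO_MULTIPLE_OF` and `SEQ_LEN_EXPONENT_BASE`:

```python
import math

# New rule (from the diff): pad the sequence length up to k * base**n,
# producing exponentially spaced buckets (128, 256, 512, 1024, ...).
def round_up_seq(number, k, base):
    exponent = max(0, math.ceil(math.log(number / k, base)))
    return int(k * (base**exponent))

# Old rule, shown only for comparison: round up to the next multiple of k.
def round_up(number, k):
    return (number + k - 1) // k * k

for n in (100, 129, 300, 1000):
    print(n, round_up(n, 128), round_up_seq(n, 128, 2))
# 100 128 128
# 129 256 256
# 300 384 512
# 1000 1024 1024
```

Fewer distinct padded shapes means fewer HPU graph recompilations, at the cost of more padding per request.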
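
Likewise, here is a rough Python rendering of the new Rust `generate_bucket_sizes` helper (illustrative only, not part of the diff). It shows why warmup now touches logarithmically many sequence lengths rather than one per `seq_bucket_size` step; as the last hunk above shows, when the final bucket falls short of `max_input_length`, the warmup code appends `max_input_length` explicitly:

```python
# Illustrative Python equivalent of the Rust helper: grow the bucket
# geometrically (* base_exp) instead of arithmetically (+ bucket_size).
def generate_bucket_sizes(bucket_size, max_s, base_exp):
    sizes = []
    current = bucket_size
    while current <= max_s:
        sizes.append(current)
        current *= base_exp  # the Rust version uses checked_mul and breaks on overflow
    return sizes

print(generate_bucket_sizes(128, 1024, 2))  # [128, 256, 512, 1024]
print(generate_bucket_sizes(128, 1000, 2))  # [128, 256, 512]; warmup then appends 1000
```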