diff --git a/pyproject.toml b/pyproject.toml
index 8364099..2ea77e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,10 @@ dependencies = [
     "sentencepiece",
     "aiohttp",
     "pydantic",
-    "matplotlib"
+    "matplotlib",
+    "librosa",
+    "soundfile",
+    "datasets",
 ]
 
 classifiers = [
@@ -23,6 +26,8 @@ classifiers = [
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
 ]
 
 [tool.setuptools]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 03bac44..f209ad0 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -16,4 +16,9 @@ pytest-timeout
 aiohttp
 openai
 httpx
-vllm
\ No newline at end of file
+vllm
+
+# audio
+librosa
+soundfile
+datasets
diff --git a/src/flexible_inference_benchmark/engine/backend_functions.py b/src/flexible_inference_benchmark/engine/backend_functions.py
index e614cdd..9f8eb66 100644
--- a/src/flexible_inference_benchmark/engine/backend_functions.py
+++ b/src/flexible_inference_benchmark/engine/backend_functions.py
@@ -21,7 +21,9 @@ class bcolors:
 
 class RequestFuncInput(BaseModel):
     prompt: str
-    media: List[str]
+    media: List[str] = Field(default_factory=list)
+    audio_file_path: Optional[str] = None
+    language: Optional[str] = None
     api_url: str
     prompt_len: int
     output_len: int
@@ -633,6 +635,107 @@ def remove_prefix(text: str, prefix: str) -> str:
     return text
 
 
+async def async_request_openai_audio_transcriptions(
+    idx: int, request_func_input: RequestFuncInput, pbar: Optional[tqdm], verbose: bool, wait_time: float
+) -> RequestFuncOutput:
+    """
+    Handle API calls to an OpenAI-compatible audio transcription endpoint.
+
+    This function manages the interaction with audio transcription APIs that follow
+    the OpenAI API format for audio endpoints (/v1/audio/transcriptions).
+ """ + api_url = request_func_input.api_url + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + if not request_func_input.audio_file_path or not os.path.exists(request_func_input.audio_file_path): + output.success = False + output.error = f"Audio file not provided or not found: {request_func_input.audio_file_path}" + if pbar: + pbar.update(1) + return output + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT, cookies=request_func_input.cookies) as session: + form = aiohttp.FormData() + form.add_field('model', request_func_input.model) + if request_func_input.language: + form.add_field('language', request_func_input.language) + + try: + audio_file_handle = open(request_func_input.audio_file_path, "rb") + form.add_field('file', + audio_file_handle, + filename=os.path.basename(request_func_input.audio_file_path), + content_type='application/octet-stream') + except IOError as e: + output.success = False + output.error = f"Could not open audio file {request_func_input.audio_file_path}: {e}" + if pbar: + pbar.update(1) + if 'audio_file_handle' in locals() and audio_file_handle: + audio_file_handle.close() + return output + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + latency = 0.0 + + try: + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with session.post(url=api_url, data=form, headers=headers, verify_ssl=request_func_input.ssl) as response: + if response.status == 200: + if request_func_input.stream: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk_text = chunk_bytes.decode("utf-8") + timestamp = time.perf_counter() + if ttft == 0.0 and chunk_text: + ttft = timestamp - st + output.ttft = ttft + elif chunk_text: + output.itl.append(timestamp - most_recent_timestamp) + + generated_text += chunk_text + most_recent_timestamp = timestamp + latency = time.perf_counter() - st + else: + resp_json = await response.json() + generated_text = resp_json.get("text", "") + latency = time.perf_counter() - st + output.ttft = latency # For non-streaming, TTFT is the full latency. + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = len(generated_text.split()) + + else: + output.success = False + error_detail = await response.text() + output.error = f"API error: {response.status} {response.reason}. Detail: {error_detail}" + + except aiohttp.ClientConnectorError: + output.success = False + output.error = "Connection error, please verify the server is running and endpoint is correct." 
+        except Exception:  # pylint: disable=broad-except
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+        finally:
+            if 'audio_file_handle' in locals() and audio_file_handle:
+                audio_file_handle.close()
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 def print_verbose(
     idx: int, request_func_input: RequestFuncInput, send_time: float, rcv_time: float, latency: float, sending: bool
 ) -> None:
@@ -663,5 +766,6 @@ def print_verbose(
     "deepspeed-mii": async_request_deepspeed_mii,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,
+    "openai-audio": async_request_openai_audio_transcriptions,  # New backend
     "tensorrt-llm": async_request_trt_llm,
 }
diff --git a/src/flexible_inference_benchmark/engine/client.py b/src/flexible_inference_benchmark/engine/client.py
index fc08df5..3f767bd 100644
--- a/src/flexible_inference_benchmark/engine/client.py
+++ b/src/flexible_inference_benchmark/engine/client.py
@@ -127,29 +127,16 @@ async def send_wave_request(
         return request_result
 
     async def benchmark(
-        self, data: List[Tuple[str, int, int]], request_times: List[Union[float, int]], requests_media: List[List[str]]
+        self,
+        prepared_requests_data: List[Dict[str, Any]],
+        request_times: List[Union[float, int]]
     ) -> list[Union[RequestFuncOutput, Any, None]]:
-        assert len(data) == len(request_times), "Data and request times must have the same length"
-        assert len(data) == len(requests_media), "Data and request media must have the same length"
-        pbar = None if self.disable_tqdm else tqdm(total=len(data))
+        assert len(prepared_requests_data) == len(request_times), "Prepared requests data and request times must have the same length"
+        pbar = None if self.disable_tqdm else tqdm(total=len(prepared_requests_data), desc="Sending Requests")
         request_func_inputs = [
-            RequestFuncInput(
-                prompt=data_sample[0],
-                media=media_sample,
-                api_url=self.api_url,
-                prompt_len=data_sample[1],
-                output_len=data_sample[2],
-                model=self.model_id,
-                best_of=self.best_of,
-                use_beam_search=self.use_beam_search,
-                ssl=self.ssl,
-                ignore_eos=self.ignore_eos,
-                stream=self.stream,
-                cookies=self.cookies,
-                logprobs=self.logprobs,
-            )
-            for (data_sample, media_sample) in zip(data, requests_media)
+            RequestFuncInput(**req_data)
+            for req_data in prepared_requests_data
         ]
 
         if self.wave:
@@ -171,22 +158,16 @@ async def benchmark(
             ]
         )
 
-    async def validate_url_endpoint(
-        self, request: Tuple[str, int, int], media_item: List[str]
+    async def validate_request_func_input(
+        self, request_input: RequestFuncInput
     ) -> Union[RequestFuncOutput, Any]:
-        data = RequestFuncInput(
-            prompt=request[0],
-            media=media_item,
-            api_url=self.api_url,
-            prompt_len=request[1],
-            output_len=request[2],
-            model=self.model_id,
-            best_of=self.best_of,
-            use_beam_search=self.use_beam_search,
-            ssl=self.ssl,
-            ignore_eos=self.ignore_eos,
-            stream=self.stream,
-            cookies=self.cookies,
-            logprobs=self.logprobs,
-        )
-        return await self.send_request(0, data, 0, None, None)
+        """
+        Validate a request by sending it to the endpoint.
+
+        Args:
+            request_input: A fully prepared RequestFuncInput instance.
+
+        Returns:
+            The response from the server.
+ """ + return await self.send_request(0, request_input, 0, None, None) diff --git a/src/flexible_inference_benchmark/engine/data.py b/src/flexible_inference_benchmark/engine/data.py index ba51876..d35ee7d 100644 --- a/src/flexible_inference_benchmark/engine/data.py +++ b/src/flexible_inference_benchmark/engine/data.py @@ -1,12 +1,18 @@ # pylint: disable=too-many-positional-arguments import abc -from typing import List, Tuple +from typing import List, Tuple, Optional import logging import json import random import os +import tempfile +import shutil +import numpy as np from hashlib import sha256 +import librosa +import soundfile +from datasets import load_dataset # type: ignore[attr-defined] from transformers.tokenization_utils_base import PreTrainedTokenizerBase # type: ignore[attr-defined] from flexible_inference_benchmark.engine import distributions @@ -256,6 +262,169 @@ def generate_data(self, size: int) -> List[Tuple[str, int, int]]: return random.sample(input_data, size) +class ASRDataset(Data): + """ + Dataset class for loading and preparing audio data for ASR benchmarking. + Originally inspired from vLLM's ASR dataset class. + + This class loads audio samples from a Hugging Face dataset, prepares them for + transcription benchmarking, and manages temporary storage of audio files. + """ + DEFAULT_AUDIO_PREAMBLE_TEMPLATE = "<|startoftranscript|><|{lang}|><|transcribe|><|notimestamps|>" + TEXT_FIELD_CANDIDATES = ['text', 'transcription', 'sentence'] + + def __init__( + self, + hf_dataset_name: str, + tokenizer: PreTrainedTokenizerBase, + hf_dataset_config: Optional[str] = None, + hf_dataset_split: str = "train", + language: str = "en", + preamble_template: Optional[str] = None, + audio_duration_limit_sec: Optional[float] = 30.0, + audio_column: str = "audio", + text_column: Optional[str] = None, + max_samples: Optional[int] = None + ): + self.tokenizer = tokenizer + self.hf_dataset_name = hf_dataset_name + self.hf_dataset_config = hf_dataset_config + self.hf_dataset_split = hf_dataset_split + self.language = language + self.preamble = (preamble_template or self.DEFAULT_AUDIO_PREAMBLE_TEMPLATE).format(lang=language) + self.audio_duration_limit_sec = audio_duration_limit_sec + self.audio_column = audio_column + self.text_column = text_column + self.max_samples = max_samples + + self.temp_dir = tempfile.mkdtemp(prefix="fib_audio_cache_") + logger.info(f"Created temporary directory for audio files: {self.temp_dir}") + + self.dataset_samples: List[Tuple[str, int, int, str]] = self._load_and_prepare_data() + + def _find_text_column(self, features) -> str: + """Find the column containing the transcription text.""" + if self.text_column and self.text_column in features: + return self.text_column + for candidate in self.TEXT_FIELD_CANDIDATES: + if candidate in features: + logger.info(f"Using '{candidate}' as text column for ASR dataset.") + return candidate + raise ValueError(f"Could not find a suitable text column (tried: {self.TEXT_FIELD_CANDIDATES}) in dataset features: {list(features.keys())}") + + def _load_and_prepare_data(self) -> List[Tuple[str, int, int, str]]: + """Load and prepare audio data from a Hugging Face dataset.""" + try: + dataset = load_dataset( + self.hf_dataset_name, + name=self.hf_dataset_config, + split=self.hf_dataset_split, + streaming=False, + ) + except Exception as e: + logger.error(f"Failed to load dataset {self.hf_dataset_name}: {e}") + return [] + + if self.max_samples is not None: + dataset = dataset.select(range(min(self.max_samples, len(dataset)))) + + 
+        prepared_data = []
+        actual_text_column = self._find_text_column(dataset.features)
+
+        logger.info(f"Preparing ASR data from {self.hf_dataset_name} ({self.hf_dataset_split}). This may take a while...")
+
+        processed_count = 0
+        skipped_duration = 0
+        skipped_missing_text = 0
+
+        for i, item in enumerate(dataset):
+            if self.audio_column not in item or not item[self.audio_column]:
+                logger.warning(f"Skipping item {i} due to missing or empty audio data in column '{self.audio_column}'.")
+                continue
+
+            audio_data = item[self.audio_column]
+
+            if not isinstance(audio_data, dict) or "array" not in audio_data or "sampling_rate" not in audio_data:
+                logger.warning(f"Skipping item {i} due to unexpected audio data format: {type(audio_data)}. Expected dict with 'array' and 'sampling_rate'.")
+                continue
+
+            y = np.array(audio_data["array"])
+            sr = audio_data["sampling_rate"]
+
+            if y.ndim > 1:  # If stereo, convert to mono.
+                y = librosa.to_mono(y.T)
+
+            duration_s = librosa.get_duration(y=y, sr=sr)
+            if self.audio_duration_limit_sec and duration_s > self.audio_duration_limit_sec:
+                skipped_duration += 1
+                continue
+
+            reference_text = item.get(actual_text_column)
+            if not reference_text or not isinstance(reference_text, str):
+                skipped_missing_text += 1
+                continue
+
+            # Save audio to a temporary WAV file, using a unique name based on the running index to avoid collisions if multiple identical audios exist.
+            temp_audio_filename = os.path.join(self.temp_dir, f"audio_sample_{processed_count}.wav")
+            try:
+                soundfile.write(temp_audio_filename, y, sr, format="WAV")
+            except Exception as e:
+                logger.error(f"Failed to write temporary audio file for item {i}: {e}")
+                continue
+
+            prompt_len = len(self.tokenizer.encode(self.preamble))
+            output_len = len(self.tokenizer.encode(reference_text))  # Expected output tokens
+
+            prepared_data.append((self.preamble, prompt_len, output_len, temp_audio_filename))
+            processed_count += 1
+
+        if skipped_duration > 0:
+            logger.info(f"Skipped {skipped_duration} audio samples due to duration limit ({self.audio_duration_limit_sec}s).")
+        if skipped_missing_text > 0:
+            logger.info(f"Skipped {skipped_missing_text} audio samples due to missing reference text.")
+
+        logger.info(f"Successfully prepared {len(prepared_data)} ASR samples.")
+        return prepared_data
+
+    def generate_data(self, size: int) -> List[Tuple[str, int, int, str]]:
+        """
+        Generate random samples from the prepared dataset.
+
+        Returns a list of tuples, where each tuple contains:
+        (preamble_text, prompt_len, output_len, audio_file_path)
+        """
+        if not self.dataset_samples:
+            logger.warning("ASR dataset is empty or failed to load. Returning no data.")
+            return []
+
+        if len(self.dataset_samples) < size:
+            logger.warning(
+                f"Requested {size} samples, but the ASR dataset only has {len(self.dataset_samples)}. "
+                f"Returning all available samples. Consider increasing --audio-max-samples or using a larger dataset split."
+            )
+            return self.dataset_samples
+        return random.sample(self.dataset_samples, size)
+
+    def cleanup_temp_dir(self):
+        """Explicitly clean up temporary files."""
+        if hasattr(self, 'temp_dir') and self.temp_dir and os.path.exists(self.temp_dir):
+            try:
+                shutil.rmtree(self.temp_dir)
+                logger.info(f"Successfully cleaned up temporary audio directory: {self.temp_dir}")
+                self.temp_dir = None
+            except OSError as e:
+                logger.error(f"Error cleaning up temporary directory {self.temp_dir}: {e}")
+
+    def __del__(self):
+        """Clean up temporary files on object destruction as a fallback."""
+        if hasattr(self, 'temp_dir') and self.temp_dir and os.path.exists(self.temp_dir):
+            logger.warning(
+                f"ASRDataset temporary directory {self.temp_dir} was not explicitly cleaned. "
+                "Attempting cleanup in __del__. Please ensure cleanup_temp_dir() is called."
+            )
+            self.cleanup_temp_dir()
+
+
 class ShareGPT(Data):
     def __init__(self, filename: str, tokenizer: PreTrainedTokenizerBase) -> None:
         # From https://github.com/vllm-project/vllm/blob/v0.4.0.post1/benchmarks/benchmark_serving.py#L310
diff --git a/src/flexible_inference_benchmark/main.py b/src/flexible_inference_benchmark/main.py
index 1f1fe0e..81ae128 100644
--- a/src/flexible_inference_benchmark/main.py
+++ b/src/flexible_inference_benchmark/main.py
@@ -7,7 +7,7 @@
 import sys
 import os
 import time
-from typing import List, Any, Tuple, Union
+from typing import List, Any, Dict, Tuple, Union
 from concurrent.futures import ThreadPoolExecutor
 import base64
 import requests
@@ -23,9 +23,9 @@
     set_max_open_files,
     download_sharegpt_dataset,
 )
-from flexible_inference_benchmark.engine.data import ShareGPT, Textfile, Random
+from flexible_inference_benchmark.engine.data import ShareGPT, Textfile, Random, ASRDataset, Data
 from flexible_inference_benchmark.engine.client import Client
-from flexible_inference_benchmark.engine.backend_functions import ASYNC_REQUEST_FUNCS
+from flexible_inference_benchmark.engine.backend_functions import ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput
 from flexible_inference_benchmark.engine.workloads import WORKLOADS_TYPES
 from flexible_inference_benchmark.data_postprocessors.performance import add_performance_parser, calculate_metrics
 from flexible_inference_benchmark.data_postprocessors.ttft import add_ttft_parser
@@ -158,6 +158,35 @@ def generate_request_times(args: argparse.Namespace) -> List[Union[int, float]]:
     return [i for i in requests_times if i <= args.max_time_for_reqs]
 
 
+def create_asr_dataset_and_generate_tasks(
+    args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase, size: int
+) -> Tuple[ASRDataset, List[Tuple[str, int, int, str]]]:
+    """
+    Create an ASRDataset instance and generate audio transcription tasks.
+
+    Args:
+        args: Command-line arguments
+        tokenizer: Tokenizer for encoding/decoding text
+        size: Number of audio tasks to generate
+
+    Returns:
+        Tuple of (ASRDataset instance, list of task tuples),
+        where each task tuple contains (preamble_text, prompt_len, output_len, audio_file_path).
+    """
+    asr_dataset = ASRDataset(
+        hf_dataset_name=args.audio_dataset_name,
+        tokenizer=tokenizer,
+        hf_dataset_config=args.audio_dataset_config,
+        hf_dataset_split=args.audio_dataset_split,
+        language=args.audio_language,
+        preamble_template=args.audio_preamble,
+        audio_duration_limit_sec=args.audio_duration_limit,
+        max_samples=args.audio_max_samples
+    )
+    data = asr_dataset.generate_data(size)
+    return asr_dataset, data
+
+
 def generate_prompts(
     args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase, size: int
 ) -> List[Tuple[str, int, int]]:
@@ -225,11 +254,13 @@ def generate_prompts(
 
 def send_requests(
     client: Client,
-    requests_prompts: List[Tuple[str, int, int]],
+    prepared_requests_data: List[Dict[str, Any]],
     requests_times: List[Union[int, float]],
-    requests_media: List[List[str]],
 ) -> List[Any]:
-    return asyncio.run(client.benchmark(requests_prompts, requests_times, requests_media))
+    """
+    Send prepared requests to the server through the client.
+    """
+    return asyncio.run(client.benchmark(prepared_requests_data, requests_times))
 
 
 def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> Any:  # type: ignore [type-arg]
@@ -389,6 +420,16 @@ def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> Any:  # type: ignore [type-arg]
         choices=["sharegpt", "sharegpt_code", "other", "random"],
         help="Name of the dataset to benchmark on.",
     )
+
+    # Audio Benchmarking Options
+    audio_group = benchmark_parser.add_argument_group('Audio Benchmarking Options (used if backend is openai-audio)')
+    audio_group.add_argument("--audio-dataset-name", type=str, default="librispeech_asr", help="Name of the Hugging Face ASR dataset (e.g., 'librispeech_asr', 'common_voice').")
+    audio_group.add_argument("--audio-dataset-config", type=str, default=None, help="Configuration for the HF dataset (e.g., 'clean' for librispeech, 'en' for common_voice).")
+    audio_group.add_argument("--audio-dataset-split", type=str, default="test.clean", help="Dataset split to use (e.g., 'test.clean', 'validation', 'train'). Check the HF dataset viewer for available splits.")
+    audio_group.add_argument("--audio-language", type=str, default="en", help="Target language for ASR (ISO 639-1 code, e.g., 'en', 'es', 'fr'). Used in the preamble.")
+    audio_group.add_argument("--audio-preamble", type=str, default=None, help="Custom preamble template for ASR. Use '{lang}' as the language placeholder. Defaults to a Whisper-style preamble.")
+    audio_group.add_argument("--audio-duration-limit", type=float, default=29.5, help="Maximum audio duration in seconds. Samples longer than this are skipped. Default 29.5s (the Whisper limit is 30s).")
+    audio_group.add_argument("--audio-max-samples", type=int, default=None, help="Maximum number of audio samples to load from the dataset during preparation. Useful for large datasets.")
 
     benchmark_parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.")
 
@@ -527,8 +568,17 @@ def fail(msg: str) -> None:
             "Do not specify workload type with ShareGPT dataset."
         )
 
-    if args.dataset_path and not args.dataset_name:
+    if args.dataset_path and not args.dataset_name and args.backend != "openai-audio":
         args.dataset_name = "other"
+
+    if args.backend == "openai-audio":
+        if not args.audio_dataset_name:
+            fail("For the 'openai-audio' backend, --audio-dataset-name must be provided.")
+        if args.workload_type:
+            logger.warning("--workload-type is ignored for the 'openai-audio' backend.")
+            args.workload_type = None
+        if args.dataset_name != "random" or args.dataset_path:
+            logger.warning("--dataset-name and --dataset-path apply to text datasets and are ignored for the 'openai-audio' backend when --audio-dataset-name is used.")
 
     return args
 
@@ -542,83 +592,134 @@ def run_main(args: argparse.Namespace) -> None:
     np.random.seed(args.seed)
     random.seed(args.seed)
     requests_times = generate_request_times(args)
-    size = len(requests_times)
-    requests_media = generate_request_media(
-        args.num_of_imgs_per_req, args.img_ratios_per_req, args.img_base_path, size, args.send_image_with_base64
-    )
+    num_actual_requests = len(requests_times)
+
     tokenizer_id = args.tokenizer if args.tokenizer else args.model
     tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(tokenizer_id)
-    requests_prompts = generate_prompts(args, tokenizer, size)
-    min_length = min(len(requests_prompts), len(requests_times))
-    requests_prompts = requests_prompts[:min_length]
-    requests_times = requests_times[:min_length]
-    requests_media = [arr_dims[:min_length] for arr_dims in requests_media]
-
-    set_max_open_files(min_length + 256)
-
-    base_url = args.base_url.strip("/")
-    endpoint = args.endpoint.strip("/")
-    args.api_url = f"{base_url}/{endpoint}"
-
-    client = Client(
-        args.backend,
-        args.api_url,
-        args.model,
-        args.best_of,
-        args.use_beam_search,
-        True if args.verbose else args.disable_tqdm,
-        args.https_ssl,
-        not args.disable_ignore_eos,
-        not args.disable_stream,
-        args.cookies,
-        args.verbose,
-        args.max_concurrent,
-        args.wave,
-        args.logprobs,
-    )
-    # disable verbose output for validation of the endpoint. This is done to avoid confusion on terminal output.
-    client_verbose_value = client.verbose
-    client.verbose = False
-    logger.info("Sending a single request for validation.")
-    validate_endpoint = asyncio.run(client.validate_url_endpoint(requests_prompts[0], requests_media[0][0]))
-    if not validate_endpoint.success:
-        logger.info(f"{validate_endpoint.error}.\nExiting benchmark ....")
-        sys.exit(1)
-    client.verbose = client_verbose_value
-    logger.info("Beginning benchmark.")
-
-    for idx, arr_dims in enumerate(requests_media):
-        if args.num_of_imgs_per_req:
-            logger.info(
-                (
-                    f"Benchmarking with {args.num_of_imgs_per_req} images per request "
-                    f"with ratio {args.img_ratios_per_req[idx]}"
-                )
+
+    prepared_requests_data: List[Dict[str, Any]] = []
+    common_req_params = {
+        "api_url": f"{args.base_url.strip('/')}/{args.endpoint.strip('/')}",
+        "model": args.model,
+        "best_of": args.best_of,
+        "use_beam_search": args.use_beam_search,
+        "ssl": args.https_ssl,
+        "ignore_eos": not args.disable_ignore_eos,
+        "stream": not args.disable_stream,
+        "cookies": args.cookies,
+        "logprobs": args.logprobs,
+    }
+
+    asr_dataset_instance = None
+
+    try:
+        if args.backend == "openai-audio":
+            asr_dataset_instance, audio_tasks = create_asr_dataset_and_generate_tasks(args, tokenizer, num_actual_requests)
+            num_actual_requests = min(num_actual_requests, len(audio_tasks))
+            requests_times = requests_times[:num_actual_requests]
+
+            for i in range(num_actual_requests):
+                prompt_text, p_len, o_len, audio_fpath = audio_tasks[i]
+                req_data = {
+                    **common_req_params,
+                    "prompt": prompt_text,
+                    "prompt_len": p_len,
+                    "output_len": o_len,
+                    "audio_file_path": audio_fpath,
+                    "language": args.audio_language,
+                }
+                prepared_requests_data.append(req_data)
+        else:
+            requests_media_for_images = generate_request_media(
+                args.num_of_imgs_per_req, args.img_ratios_per_req, args.img_base_path, num_actual_requests, args.send_image_with_base64
+            )
+
+            text_prompts = generate_prompts(args, tokenizer, num_actual_requests)
+            num_actual_requests = min(num_actual_requests, len(text_prompts))
+            requests_times = requests_times[:num_actual_requests]
+
+            for i in range(num_actual_requests):
+                prompt_text, p_len, o_len = text_prompts[i]
+                req_data = {
+                    **common_req_params,
+                    "prompt": prompt_text,
+                    "prompt_len": p_len,
+                    "output_len": o_len,
+                    "media": requests_media_for_images[0][i] if requests_media_for_images and len(requests_media_for_images[0]) > i else [],
+                }
+                prepared_requests_data.append(req_data)
+
+        if not prepared_requests_data:
+            logger.error("No requests were prepared. Exiting.")
Exiting.") + if asr_dataset_instance: + asr_dataset_instance.cleanup_temp_dir() + sys.exit(1) + + num_actual_requests = len(prepared_requests_data) + requests_times = requests_times[:num_actual_requests] + + set_max_open_files(num_actual_requests + 256) + + client = Client( + args.backend, + "", + args.model, + args.best_of, + args.use_beam_search, + True if args.verbose else args.disable_tqdm, + args.https_ssl, + not args.disable_ignore_eos, + not args.disable_stream, + args.cookies, + args.verbose, + args.max_concurrent, + args.wave, + args.logprobs, + ) + + first_prepared_req_for_validation = RequestFuncInput(**prepared_requests_data[0]) + + client_verbose_value = client.verbose + client.verbose = False + logger.info("Sending a single request for validation.") + validate_endpoint = asyncio.run(client.validate_request_func_input(first_prepared_req_for_validation)) + if not validate_endpoint.success: + logger.error(f"Validation request failed: {validate_endpoint.error}.\nExiting benchmark ....") + return + client.verbose = client_verbose_value + logger.info("Beginning benchmark.") + t = time.perf_counter() - output_list: List[Any] = send_requests(client, requests_prompts, requests_times, arr_dims) + output_list: List[Any] = send_requests(client, prepared_requests_data, requests_times) benchmark_time = time.perf_counter() - t - # pylint: disable=line-too-long - output = { + + output_json = { "backend": args.backend, "time": benchmark_time, - "outputs": [request_func_output.model_dump() for request_func_output in output_list], # type: ignore - "inputs": requests_prompts, - "tokenizer": args.tokenizer if args.tokenizer else args.model, + "outputs": [out.model_dump() for out in output_list if out is not None], + "inputs_config": prepared_requests_data, + "tokenizer": tokenizer_id, "stream": not args.disable_stream, } if args.output_file: - filename = args.output_file - if args.num_of_imgs_per_req: - w, h = args.img_ratios_per_req[idx] - filename = f"ratio_{w}x{h}_{filename}" - with open(filename, "w") as f: - f.write(json.dumps(output, indent=4)) # type: ignore + with open(args.output_file, "w") as f: + json.dump(output_json, f, indent=4) if args.debug: - logger.debug(f"{output_list}") + logger.debug(f"Raw outputs: {[out.model_dump() for out in output_list if out is not None]}") - calculate_metrics(output["inputs"], output["outputs"], output["time"], tokenizer, output["stream"]) + simplified_inputs = None + if args.backend == "openai-audio": + simplified_inputs = [(req["prompt"], req["prompt_len"], req["output_len"]) for req in prepared_requests_data] + else: + simplified_inputs = [(req["prompt"], req["prompt_len"], req["output_len"]) for req in prepared_requests_data] + + calculate_metrics(simplified_inputs, output_json["outputs"], output_json["time"], tokenizer, output_json["stream"]) + + finally: + if asr_dataset_instance: + logger.info("Cleaning up ASRDataset temporary directory...") + asr_dataset_instance.cleanup_temp_dir() def main() -> None: