55 changes: 44 additions & 11 deletions tests/utils.py
@@ -310,14 +310,38 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""

compare_all_settings(
model,
[arg1, arg2],
[env1, env2],
method=method,
max_wait_seconds=max_wait_seconds,
)


def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: Literal["generate", "encode"] = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch the API server with several different sets of arguments/environments
and compare the results of the API calls against those of the first set.

Args:
model: The model to test.
all_args: A list of argument lists to pass to the API server.
all_envs: A list of environment dictionaries to pass to the API server.
"""

trust_remote_code = False
for args in (arg1, arg2):
for args in all_args:
if "--trust-remote-code" in args:
trust_remote_code = True
break

tokenizer_mode = "auto"
for args in (arg1, arg2):
for args in all_args:
if "--tokenizer-mode" in args:
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
break
@@ -330,8 +354,10 @@ def compare_two_settings(model: str,

prompt = "Hello, my name is"
token_ids = tokenizer(prompt).input_ids
results = []
for args, env in ((arg1, env1), (arg2, env2)):
ref_results: List = []
for i, (args, env) in enumerate(zip(all_args, all_envs)):
compare_results: List = []
results = ref_results if i == 0 else compare_results
with RemoteOpenAIServer(model,
args,
env_dict=env,
@@ -355,13 +381,20 @@ def compare_two_settings(model: str,
else:
assert_never(method)

n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if i > 0:
# compare against the first (reference) setting and fail early on mismatch
ref_args = all_args[0]
ref_envs = all_envs[0]
compare_args = all_args[i]
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
f"{compare_args=} {compare_envs=}\n"
f"{ref_result=}\n"
f"{compare_result=}\n")


def init_test_distributed_environment(
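For context, here is a minimal usage sketch of the new compare_all_settings helper written as a pytest-style test. The import path, model name, and CLI flags below are illustrative assumptions for demonstration and are not taken from this PR.

# Illustrative sketch only: the import path, model name, and flags are assumptions.
from tests.utils import compare_all_settings

def test_settings_give_identical_outputs():
    # The first entry is the reference; every later entry is compared against it.
    compare_all_settings(
        "facebook/opt-125m",          # hypothetical small test model
        [
            [],                       # reference: default server arguments
            ["--enforce-eager"],      # candidate 1: eager execution
            ["--max-num-seqs", "2"],  # candidate 2: smaller batch limit
        ],
        [None, None, None],           # no extra environment variables
        method="generate",
    )

Because the helper compares each later setting element-wise against the reference results inside the loop, a mismatch in any candidate setting raises immediately rather than after all servers have been launched.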