tests/models/language/generation/test_hybrid.py (13 additions, 0 deletions)
@@ -57,6 +57,13 @@
# Avoid OOM
MAX_NUM_SEQS = 4

# Once FCG (full CUDA graph) support is added for Mamba1, this list will be
# removed and all test cases will use enforce_eager=False
ENFORCE_EAGER_MODELS_V1 = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
]
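# (test_models below sets enforce_eager=True for the models in this list)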


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
            example_prompts, max_tokens, num_logprobs)

    if model in V1_SUPPORTED_MODELS:
        enforce_eager = False
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

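            # Mamba1-based models cannot use full CUDA graphs yet (see
            # ENFORCE_EAGER_MODELS_V1 above), so run them in eager mode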
            if model in ENFORCE_EAGER_MODELS_V1:
                enforce_eager = True

            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enforce_eager=enforce_eager,
                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)