From 72726b78e028a14bd9a7d9bf2cbd2b4b55ec7373 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:25:06 +0100 Subject: [PATCH 1/5] Remove unnecessary file from root dir Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .yapfignore | 1 - pyproject.toml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 .yapfignore diff --git a/.yapfignore b/.yapfignore deleted file mode 100644 index 2d6dcf8380ca..000000000000 --- a/.yapfignore +++ /dev/null @@ -1 +0,0 @@ -collect_env.py diff --git a/pyproject.toml b/pyproject.toml index 85a112ff51cf..7bd1c59a0931 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ ignore_patterns = [ "benchmarks/**", "build/**", "examples/**", + "vllm/collect_env.py", ] [tool.ruff] From ceb7cc265fbe5a9a0147467a85f59c852d5e0eb3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:25:06 +0100 Subject: [PATCH 2/5] Migrate `tests/` from `yapf` to `ruff-format` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .pre-commit-config.yaml | 17 +++++++++++-- pyproject.toml | 2 ++ tests/pyproject.toml | 54 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 tests/pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5197820fb402..43418dd8fd79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,27 @@ repos: - id: yapf args: [--in-place, --verbose] # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' + exclude: | + (?x)^( + .buildkite| + benchmarks| + build| + examples| + tests + )/.*$ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* + files: | + (?x)^( + .buildkite| + benchmarks| + examples| + tests + )/.*$ - repo: https://github.com/crate-ci/typos rev: v1.34.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 7bd1c59a0931..3ce2bb1c09a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ ignore_patterns = [ "benchmarks/**", "build/**", "examples/**", + "tests/**", "vllm/collect_env.py", ] @@ -143,6 +144,7 @@ skip_glob = [ ".buildkite/*", "benchmarks/*", "examples/*", + "tests/*", ] use_parentheses = true skip_gitignore = true diff --git a/tests/pyproject.toml b/tests/pyproject.toml new file mode 100644 index 000000000000..f825cb203269 --- /dev/null +++ b/tests/pyproject.toml @@ -0,0 +1,54 @@ +# This local pyproject file is part of the migration from yapf to ruff format. 
+# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 +exclude = [ + # External file, leaving license intact + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" +] + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm"] + +[tool.ruff.format] +docstring-code-format = true \ No newline at end of file From e4c86294e6962379ca3f9fe671d9dedbe5076689 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 28 Jul 2025 11:31:47 +0100 Subject: [PATCH 3/5] `pre-commit run -a` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/async_engine/api_server_async_engine.py | 14 +- tests/async_engine/conftest.py | 2 +- tests/async_engine/test_api_server.py | 31 +- tests/async_engine/test_async_llm_engine.py | 60 +- tests/async_engine/test_request_tracker.py | 3 +- .../test_basic_correctness.py | 133 +- .../basic_correctness/test_chunked_prefill.py | 136 +- tests/basic_correctness/test_cpu_offload.py | 5 +- tests/basic_correctness/test_cumem.py | 23 +- tests/basic_correctness/test_preemption.py | 87 +- tests/benchmarks/test_latency_cli.py | 14 +- tests/benchmarks/test_serve_cli.py | 4 +- tests/benchmarks/test_throughput_cli.py | 14 +- tests/build_cython.py | 15 +- tests/compile/backend.py | 17 +- .../compile/piecewise/test_full_cudagraph.py | 108 +- tests/compile/piecewise/test_simple.py | 59 +- tests/compile/piecewise/test_toy_llama.py | 290 ++-- tests/compile/test_async_tp.py | 139 +- tests/compile/test_basic_correctness.py | 42 +- tests/compile/test_config.py | 55 +- tests/compile/test_full_graph.py | 111 +- tests/compile/test_functionalization.py | 62 +- tests/compile/test_fusion.py | 80 +- tests/compile/test_fusion_all_reduce.py | 118 +- tests/compile/test_fusion_attn.py | 57 +- tests/compile/test_pass_manager.py | 9 +- tests/compile/test_sequence_parallelism.py | 213 +-- tests/compile/test_silu_mul_quant_fusion.py | 50 +- tests/compile/test_wrapper.py | 13 +- tests/config/test_config_generation.py | 10 +- tests/config/test_mp_reducer.py | 14 +- tests/conftest.py | 414 +++--- tests/core/block/e2e/conftest.py | 36 +- tests/core/block/e2e/test_correctness.py | 385 +++--- .../e2e/test_correctness_sliding_window.py | 86 +- tests/core/block/test_block_manager.py | 182 +-- tests/core/block/test_block_table.py | 201 +-- .../block/test_cpu_gpu_block_allocator.py | 31 +- tests/core/block/test_naive_block.py | 96 +- tests/core/block/test_prefix_caching_block.py | 382 +++--- tests/core/conftest.py | 2 +- tests/core/test_chunked_prefill_scheduler.py | 173 ++- tests/core/test_num_computed_tokens_update.py | 49 +- tests/core/test_scheduler.py | 518 ++++---- tests/core/test_scheduler_encoder_decoder.py | 51 +- tests/core/test_serialization.py 
| 14 +- tests/core/utils.py | 109 +- tests/cuda/test_cuda_context.py | 45 +- .../test_disable_detokenization.py | 20 +- tests/detokenizer/test_stop_checker.py | 32 +- tests/detokenizer/test_stop_reason.py | 37 +- tests/detokenizer/test_stop_strings.py | 100 +- tests/distributed/conftest.py | 6 +- tests/distributed/test_ca_buffer_sharing.py | 6 +- tests/distributed/test_comm_ops.py | 106 +- tests/distributed/test_custom_all_reduce.py | 45 +- tests/distributed/test_distributed_oot.py | 3 +- tests/distributed/test_eplb_algo.py | 178 +-- tests/distributed/test_eplb_execute.py | 181 +-- tests/distributed/test_events.py | 53 +- tests/distributed/test_expert_parallel.py | 51 +- .../distributed/test_multi_node_assignment.py | 10 +- tests/distributed/test_node_count.py | 15 +- tests/distributed/test_pipeline_parallel.py | 215 +-- tests/distributed/test_pipeline_partition.py | 4 +- tests/distributed/test_pp_cudagraph.py | 20 +- tests/distributed/test_pynccl.py | 234 ++-- tests/distributed/test_quick_all_reduce.py | 74 +- tests/distributed/test_same_node.py | 6 +- tests/distributed/test_sequence_parallel.py | 139 +- tests/distributed/test_shm_broadcast.py | 24 +- tests/distributed/test_torchrun_example.py | 30 +- tests/distributed/test_utils.py | 67 +- tests/encoder_decoder/test_e2e_correctness.py | 52 +- tests/engine/conftest.py | 2 +- tests/engine/test_arg_utils.py | 175 +-- tests/engine/test_computed_prefix_blocks.py | 12 +- tests/engine/test_executor.py | 22 +- .../test_multi_step_output_processor.py | 79 +- tests/engine/test_multiproc_workers.py | 20 +- tests/engine/test_options.py | 15 +- tests/engine/test_short_mm_context.py | 21 +- tests/entrypoints/conftest.py | 139 +- tests/entrypoints/llm/test_accuracy.py | 19 +- tests/entrypoints/llm/test_chat.py | 85 +- tests/entrypoints/llm/test_collective_rpc.py | 12 +- tests/entrypoints/llm/test_encode.py | 43 +- tests/entrypoints/llm/test_generate.py | 39 +- .../llm/test_generate_multiple_loras.py | 23 +- tests/entrypoints/llm/test_gpu_utilization.py | 5 +- tests/entrypoints/llm/test_guided_generate.py | 377 +++--- tests/entrypoints/llm/test_lazy_outlines.py | 40 +- .../entrypoints/llm/test_prompt_validation.py | 4 +- .../offline_mode/test_offline_mode.py | 10 +- .../openai/correctness/test_lmeval.py | 26 +- .../openai/correctness/test_mteb_embed.py | 14 +- .../openai/correctness/test_mteb_score.py | 22 +- .../test_transcription_api_correctness.py | 58 +- .../openai/test_async_tokenization.py | 50 +- tests/entrypoints/openai/test_audio.py | 274 ++-- tests/entrypoints/openai/test_basic.py | 58 +- tests/entrypoints/openai/test_chat.py | 802 ++++++------ tests/entrypoints/openai/test_chat_echo.py | 21 +- .../openai/test_chat_logit_bias_validation.py | 10 +- .../entrypoints/openai/test_chat_template.py | 74 +- .../openai/test_chat_with_tool_reasoning.py | 101 +- .../entrypoints/openai/test_chunked_prompt.py | 27 +- .../entrypoints/openai/test_classification.py | 76 +- tests/entrypoints/openai/test_cli_args.py | 119 +- tests/entrypoints/openai/test_completion.py | 399 +++--- .../test_completion_with_function_calling.py | 80 +- .../test_completion_with_prompt_embeds.py | 61 +- .../openai/test_default_mm_loras.py | 43 +- tests/entrypoints/openai/test_embedding.py | 230 ++-- .../openai/test_embedding_dimensions.py | 36 +- .../openai/test_encoder_decoder.py | 10 +- .../entrypoints/openai/test_lora_adapters.py | 175 ++- .../entrypoints/openai/test_lora_resolvers.py | 58 +- tests/entrypoints/openai/test_metrics.py | 137 +- 
tests/entrypoints/openai/test_models.py | 4 +- .../openai/test_oot_registration.py | 11 +- .../entrypoints/openai/test_openai_schema.py | 19 +- .../openai/test_optional_middleware.py | 12 +- tests/entrypoints/openai/test_pooling.py | 175 +-- .../openai/test_prompt_validation.py | 35 +- tests/entrypoints/openai/test_rerank.py | 67 +- .../openai/test_return_tokens_as_ids.py | 42 +- tests/entrypoints/openai/test_root_path.py | 37 +- tests/entrypoints/openai/test_run_batch.py | 124 +- tests/entrypoints/openai/test_score.py | 146 ++- tests/entrypoints/openai/test_serving_chat.py | 168 +-- .../entrypoints/openai/test_serving_models.py | 61 +- tests/entrypoints/openai/test_shutdown.py | 9 +- tests/entrypoints/openai/test_sleep.py | 27 +- .../openai/test_tensorizer_entrypoint.py | 29 +- tests/entrypoints/openai/test_tokenization.py | 182 ++- .../openai/test_transcription_validation.py | 107 +- .../openai/test_translation_validation.py | 60 +- tests/entrypoints/openai/test_truncation.py | 23 +- tests/entrypoints/openai/test_video.py | 276 ++-- tests/entrypoints/openai/test_vision.py | 294 ++--- .../openai/test_vision_embedding.py | 45 +- .../test_hunyuan_a13b_tool_parser.py | 190 +-- .../test_llama4_pythonic_tool_parser.py | 249 ++-- .../tool_parsers/test_pythonic_tool_parser.py | 189 ++- .../entrypoints/openai/tool_parsers/utils.py | 92 +- .../test_api_server_process_manager.py | 68 +- tests/entrypoints/test_chat_utils.py | 1162 +++++++---------- tests/entrypoints/test_ssl_cert_refresher.py | 3 +- .../test_fastsafetensors_loader.py | 3 +- .../test_weight_utils.py | 26 +- tests/kernels/allclose_default.py | 6 +- tests/kernels/attention/conftest.py | 3 +- tests/kernels/attention/test_attention.py | 199 ++- .../attention/test_attention_selector.py | 205 +-- .../attention/test_blocksparse_attention.py | 83 +- tests/kernels/attention/test_cache.py | 410 +++--- .../attention/test_cascade_flash_attn.py | 62 +- .../attention/test_encoder_decoder_attn.py | 427 +++--- tests/kernels/attention/test_flash_attn.py | 132 +- tests/kernels/attention/test_flashinfer.py | 285 ++-- ...test_flashinfer_trtllm_decode_attention.py | 56 +- tests/kernels/attention/test_flashmla.py | 60 +- .../kernels/attention/test_lightning_attn.py | 100 +- .../attention/test_merge_attn_states.py | 218 ++-- tests/kernels/attention/test_mha_attn.py | 19 +- .../kernels/attention/test_mla_decode_cpu.py | 34 +- .../kernels/attention/test_prefix_prefill.py | 394 +++--- .../attention/test_rocm_attention_selector.py | 48 +- .../attention/test_triton_decode_attention.py | 12 +- .../test_triton_unified_attention.py | 62 +- tests/kernels/core/test_activation.py | 43 +- .../core/test_fused_quant_layernorm.py | 118 +- tests/kernels/core/test_layernorm.py | 72 +- tests/kernels/core/test_opcheck.py | 1 - tests/kernels/core/test_permute_cols.py | 8 +- tests/kernels/core/test_pos_encoding.py | 227 ++-- tests/kernels/core/test_rotary_embedding.py | 80 +- tests/kernels/core/test_uva.py | 18 +- tests/kernels/mamba/test_causal_conv1d.py | 260 ++-- tests/kernels/mamba/test_mamba_mixer2.py | 85 +- tests/kernels/mamba/test_mamba_ssm.py | 595 +++++---- tests/kernels/mamba/test_mamba_ssm_ssd.py | 215 ++- .../moe/modular_kernel_tools/cli_args.py | 95 +- .../moe/modular_kernel_tools/common.py | 258 ++-- .../make_feature_matrix.py | 124 +- .../moe/modular_kernel_tools/mk_objects.py | 83 +- .../modular_kernel_tools/parallel_utils.py | 28 +- .../profile_modular_kernel.py | 53 +- .../kernels/moe/modular_kernel_tools/utils.py | 56 +- tests/kernels/moe/parallel_utils.py 
| 105 +- tests/kernels/moe/test_batched_moe.py | 106 +- tests/kernels/moe/test_block_fp8.py | 126 +- tests/kernels/moe/test_block_int8.py | 50 +- .../moe/test_count_expert_num_tokens.py | 114 +- .../kernels/moe/test_cutlass_grouped_gemm.py | 71 +- tests/kernels/moe/test_cutlass_moe.py | 352 ++--- tests/kernels/moe/test_deepep_deepgemm_moe.py | 306 +++-- tests/kernels/moe/test_deepep_moe.py | 284 ++-- tests/kernels/moe/test_deepgemm.py | 48 +- .../moe/test_modular_kernel_combinations.py | 119 +- tests/kernels/moe/test_moe.py | 440 ++++--- .../kernels/moe/test_moe_align_block_size.py | 54 +- .../kernels/moe/test_moe_permute_unpermute.py | 243 ++-- tests/kernels/moe/test_mxfp4_moe.py | 42 +- tests/kernels/moe/test_nvfp4_moe.py | 110 +- tests/kernels/moe/test_pplx_cutlass_moe.py | 204 +-- tests/kernels/moe/test_pplx_moe.py | 320 +++-- tests/kernels/moe/test_rocm_aiter_topk.py | 199 +-- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 16 +- tests/kernels/moe/test_triton_moe_ptpc_fp8.py | 42 +- tests/kernels/moe/utils.py | 141 +- tests/kernels/quant_utils.py | 162 +-- tests/kernels/quantization/nvfp4_utils.py | 14 +- .../quantization/test_allspark_gemm.py | 80 +- tests/kernels/quantization/test_aqlm.py | 43 +- tests/kernels/quantization/test_awq.py | 52 +- tests/kernels/quantization/test_awq_triton.py | 105 +- tests/kernels/quantization/test_block_fp8.py | 56 +- tests/kernels/quantization/test_block_int8.py | 23 +- .../quantization/test_cutlass_2of4_sparse.py | 179 ++- .../quantization/test_cutlass_scaled_mm.py | 589 +++++---- tests/kernels/quantization/test_fp8_quant.py | 97 +- tests/kernels/quantization/test_ggml.py | 53 +- tests/kernels/quantization/test_gguf.py | 108 +- tests/kernels/quantization/test_gptq.py | 29 +- .../kernels/quantization/test_int8_kernel.py | 44 +- tests/kernels/quantization/test_int8_quant.py | 98 +- tests/kernels/quantization/test_machete_mm.py | 300 +++-- .../kernels/quantization/test_marlin_gemm.py | 322 +++-- .../kernels/quantization/test_nvfp4_quant.py | 44 +- .../quantization/test_nvfp4_scaled_mm.py | 85 +- .../quantization/test_rocm_skinny_gemms.py | 22 +- .../quantization/test_triton_scaled_mm.py | 55 +- .../test_apply_repetition_penalties.py | 64 +- tests/kernels/test_cutlass_mla_decode.py | 47 +- tests/kernels/test_flex_attention.py | 13 +- tests/kernels/test_fused_quant_activation.py | 20 +- tests/kernels/test_triton_flash_attention.py | 530 ++++---- tests/kernels/utils.py | 649 ++++----- tests/kv_transfer/test_disagg.py | 47 +- tests/kv_transfer/test_lookup_buffer.py | 23 +- tests/kv_transfer/test_module.py | 25 +- tests/kv_transfer/test_send_recv.py | 30 +- tests/lora/conftest.py | 131 +- tests/lora/test_add_lora.py | 47 +- tests/lora/test_baichuan.py | 81 +- tests/lora/test_chatglm3_tp.py | 66 +- tests/lora/test_default_mm_loras.py | 13 +- tests/lora/test_layers.py | 837 ++++++------ tests/lora/test_llama_tp.py | 179 ++- tests/lora/test_lora_allowed_token_ids.py | 32 +- tests/lora/test_lora_checkpoints.py | 42 +- tests/lora/test_lora_functions.py | 44 +- tests/lora/test_lora_huggingface.py | 11 +- tests/lora/test_lora_manager.py | 462 ++++--- tests/lora/test_minicpmv_tp.py | 53 +- tests/lora/test_mixtral.py | 29 +- tests/lora/test_peft_helper.py | 31 +- tests/lora/test_phi.py | 39 +- tests/lora/test_punica_ops.py | 236 ++-- tests/lora/test_quant_model.py | 80 +- tests/lora/test_qwen2vl.py | 127 +- tests/lora/test_resolver.py | 11 +- tests/lora/test_tokenizer_group.py | 30 +- tests/lora/test_transformers_model.py | 76 +- tests/lora/test_utils.py | 115 +- 
tests/lora/test_worker.py | 45 +- tests/lora/utils.py | 92 +- tests/metrics/test_metrics.py | 166 +-- tests/mistral_tool_use/conftest.py | 10 +- .../test_mistral_tool_calls.py | 6 +- tests/mistral_tool_use/utils.py | 13 +- tests/model_executor/conftest.py | 41 +- .../model_executor/test_enabled_custom_ops.py | 122 +- .../model_executor/test_guided_processors.py | 154 +-- tests/model_executor/test_logits_processor.py | 49 +- .../test_model_load_with_params.py | 59 +- tests/model_executor/test_weight_utils.py | 39 +- tests/models/language/generation/test_bart.py | 97 +- .../models/language/generation/test_common.py | 60 +- .../models/language/generation/test_gemma.py | 12 +- .../language/generation/test_granite.py | 6 +- .../models/language/generation/test_hybrid.py | 125 +- .../language/generation/test_mistral.py | 333 +++-- .../models/language/generation/test_phimoe.py | 78 +- tests/models/language/pooling/embed_utils.py | 31 +- tests/models/language/pooling/mteb_utils.py | 154 +-- tests/models/language/pooling/test_baai.py | 118 +- .../pooling/test_bge_reranker_v2_gemma.py | 72 +- .../language/pooling/test_classification.py | 20 +- .../language/pooling/test_cross_encoder.py | 17 +- .../models/language/pooling/test_embedding.py | 51 +- tests/models/language/pooling/test_gritlm.py | 25 +- tests/models/language/pooling/test_gte.py | 106 +- .../models/language/pooling/test_intfloat.py | 54 +- tests/models/language/pooling/test_jina.py | 78 +- .../language/pooling/test_mxbai_rerank.py | 56 +- tests/models/language/pooling/test_nomic.py | 40 +- .../pooling/test_nomic_max_model_len.py | 69 +- .../language/pooling/test_qwen3_reranker.py | 69 +- tests/models/language/pooling/test_reward.py | 22 +- tests/models/language/pooling/test_scoring.py | 63 +- .../pooling/test_snowflake_arctic_embed.py | 92 +- .../pooling/test_truncation_control.py | 52 +- .../multimodal/generation/test_common.py | 197 ++- .../multimodal/generation/test_florence2.py | 83 +- .../generation/test_granite_speech.py | 73 +- .../multimodal/generation/test_interleaved.py | 38 +- .../multimodal/generation/test_mllama.py | 482 ++++--- .../multimodal/generation/test_phi4mm.py | 177 +-- .../multimodal/generation/test_pixtral.py | 171 +-- .../multimodal/generation/test_qwen2_vl.py | 279 ++-- .../multimodal/generation/test_ultravox.py | 146 ++- .../multimodal/generation/test_voxtral.py | 70 +- .../multimodal/generation/test_whisper.py | 33 +- .../generation/vlm_utils/builders.py | 168 +-- .../generation/vlm_utils/case_filtering.py | 94 +- .../multimodal/generation/vlm_utils/core.py | 63 +- .../generation/vlm_utils/custom_inputs.py | 35 +- .../generation/vlm_utils/model_utils.py | 257 ++-- .../generation/vlm_utils/runners.py | 99 +- .../multimodal/generation/vlm_utils/types.py | 25 +- .../multimodal/pooling/test_dse_qwen2_vl.py | 117 +- .../multimodal/pooling/test_intern_vit.py | 26 +- .../pooling/test_jinavl_reranker.py | 78 +- .../multimodal/pooling/test_llava_next.py | 53 +- tests/models/multimodal/pooling/test_phi3v.py | 47 +- .../multimodal/processing/test_common.py | 74 +- .../multimodal/processing/test_h2ovl.py | 26 +- .../multimodal/processing/test_idefics3.py | 14 +- .../multimodal/processing/test_internvl.py | 13 +- .../multimodal/processing/test_llama4.py | 37 +- .../multimodal/processing/test_llava_next.py | 47 +- .../processing/test_llava_onevision.py | 56 +- .../processing/test_minimax_vl_01.py | 26 +- .../multimodal/processing/test_mllama.py | 10 +- .../multimodal/processing/test_nemotron_vl.py | 16 +- 
.../multimodal/processing/test_phi3v.py | 4 +- .../multimodal/processing/test_phi4mm.py | 10 +- .../multimodal/processing/test_qwen2_vl.py | 6 +- .../multimodal/processing/test_smolvlm.py | 14 +- tests/models/multimodal/test_mapping.py | 7 +- tests/models/quantization/test_aqlm.py | 37 +- tests/models/quantization/test_awq.py | 87 +- tests/models/quantization/test_bitblas.py | 26 +- .../models/quantization/test_bitsandbytes.py | 236 ++-- tests/models/quantization/test_fp8.py | 107 +- tests/models/quantization/test_gguf.py | 72 +- .../models/quantization/test_gptq_bitblas.py | 20 +- tests/models/quantization/test_gptq_marlin.py | 50 +- .../quantization/test_gptq_marlin_24.py | 41 +- tests/models/quantization/test_modelopt.py | 37 +- tests/models/quantization/test_mxfp4.py | 21 +- tests/models/quantization/test_nvfp4.py | 43 +- tests/models/registry.py | 6 +- tests/models/test_initialization.py | 85 +- tests/models/test_oot_registration.py | 41 +- tests/models/test_registry.py | 75 +- tests/models/test_transformers.py | 79 +- tests/models/test_utils.py | 38 +- tests/models/test_vision.py | 24 +- tests/models/utils.py | 127 +- tests/mq_llm_engine/conftest.py | 2 +- tests/mq_llm_engine/test_abort.py | 18 +- tests/mq_llm_engine/test_error_handling.py | 158 +-- tests/mq_llm_engine/test_load.py | 18 +- tests/mq_llm_engine/utils.py | 36 +- .../multi_step/test_correctness_async_llm.py | 75 +- tests/multi_step/test_correctness_llm.py | 152 ++- tests/multimodal/test_inputs.py | 34 +- tests/multimodal/test_processing.py | 34 +- tests/multimodal/test_utils.py | 164 ++- tests/multimodal/test_video.py | 21 +- tests/multimodal/utils.py | 4 +- tests/neuron/1_core/test_activation.py | 16 +- tests/neuron/1_core/test_block_table.py | 35 +- tests/neuron/1_core/test_cache.py | 34 +- tests/neuron/1_core/test_layernorm.py | 27 +- tests/neuron/1_core/test_logits_processor.py | 45 +- .../neuron/1_core/test_neuron_model_runner.py | 35 +- tests/neuron/1_core/test_neuron_quant.py | 3 +- tests/neuron/1_core/test_prefix_prefill.py | 177 ++- tests/neuron/1_core/test_rotary_embedding.py | 56 +- tests/neuron/2_core/test_comm_ops.py | 50 +- tests/neuron/2_core/test_eagle.py | 29 +- tests/neuron/2_core/test_mistral.py | 29 +- tests/neuron/2_core/test_multi_lora.py | 100 +- .../test_filesystem_resolver.py | 5 +- tests/plugins/vllm_add_dummy_model/setup.py | 15 +- .../vllm_add_dummy_model/__init__.py | 3 +- .../my_gemma_embedding.py | 15 +- .../vllm_add_dummy_model/my_llava.py | 23 +- .../vllm_add_dummy_model/my_opt.py | 5 +- .../plugins/vllm_add_dummy_platform/setup.py | 16 +- .../dummy_attention_backend.py | 4 +- .../dummy_custom_ops.py | 3 +- .../vllm_add_dummy_platform/dummy_platform.py | 14 +- tests/plugins_tests/conftest.py | 2 +- tests/plugins_tests/test_platform_plugins.py | 17 +- tests/plugins_tests/test_scheduler_plugins.py | 10 +- .../test_disable_sliding_window.py | 28 +- tests/prefix_caching/test_prefix_caching.py | 75 +- tests/prompt_adapter/test_bloom.py | 38 +- .../test_multi_adapter_inference.py | 60 +- tests/prompt_adapter/test_pa_lora.py | 46 +- tests/quantization/reference_mxfp4.py | 125 +- tests/quantization/test_auto_round.py | 21 +- tests/quantization/test_compressed_tensors.py | 215 +-- tests/quantization/test_configs.py | 24 +- tests/quantization/test_cpu_offload.py | 183 +-- tests/quantization/test_experts_int8.py | 13 +- tests/quantization/test_fp8.py | 95 +- tests/quantization/test_gptq_dynamic.py | 51 +- tests/quantization/test_ipex_quant.py | 18 +- tests/quantization/test_lm_head.py | 29 +- 
tests/quantization/test_ptpc_fp8.py | 32 +- tests/quantization/test_quark.py | 105 +- .../test_register_quantization_config.py | 54 +- tests/quantization/test_rtn.py | 10 +- tests/quantization/test_torchao.py | 58 +- .../test_deepseekr1_reasoning_parser.py | 17 +- .../test_granite_reasoning_parser.py | 48 +- .../test_hunyuan_reasoning_parser.py | 35 +- .../reasoning/test_qwen3_reasoning_parser.py | 14 +- tests/reasoning/utils.py | 14 +- .../test_weight_utils.py | 15 +- tests/samplers/test_beam_search.py | 41 +- tests/samplers/test_ignore_eos.py | 6 +- tests/samplers/test_logits_processor.py | 5 +- tests/samplers/test_logprobs.py | 105 +- tests/samplers/test_no_bad_words.py | 106 +- tests/samplers/test_ranks.py | 24 +- tests/samplers/test_rejection_sampler.py | 442 ++++--- tests/samplers/test_sampler.py | 301 +++-- tests/samplers/test_seeded_generate.py | 20 +- .../test_typical_acceptance_sampler.py | 248 ++-- tests/spec_decode/conftest.py | 2 +- tests/spec_decode/e2e/conftest.py | 208 +-- tests/spec_decode/e2e/test_compatibility.py | 21 +- .../spec_decode/e2e/test_eagle_correctness.py | 709 +++++----- tests/spec_decode/e2e/test_integration.py | 192 +-- .../e2e/test_integration_dist_tp2.py | 459 ++++--- .../e2e/test_integration_dist_tp4.py | 156 ++- tests/spec_decode/e2e/test_logprobs.py | 345 +++-- .../e2e/test_medusa_correctness.py | 567 ++++---- tests/spec_decode/e2e/test_mlp_correctness.py | 720 +++++----- tests/spec_decode/e2e/test_mtp_correctness.py | 467 ++++--- .../e2e/test_multistep_correctness.py | 1050 ++++++++------- .../spec_decode/e2e/test_ngram_correctness.py | 567 ++++---- tests/spec_decode/e2e/test_seed.py | 45 +- tests/spec_decode/test_batch_expansion.py | 68 +- tests/spec_decode/test_dynamic_spec_decode.py | 64 +- tests/spec_decode/test_memory_usage.py | 33 +- tests/spec_decode/test_metrics.py | 143 +- tests/spec_decode/test_multi_step_worker.py | 468 ++++--- tests/spec_decode/test_ngram_worker.py | 76 +- tests/spec_decode/test_scorer.py | 119 +- tests/spec_decode/test_spec_decode_worker.py | 840 ++++++------ tests/spec_decode/test_utils.py | 70 +- tests/spec_decode/utils.py | 181 +-- tests/standalone_tests/lazy_imports.py | 3 +- tests/tensorizer_loader/conftest.py | 21 +- tests/tensorizer_loader/test_tensorizer.py | 248 ++-- tests/test_cache_block_hashing.py | 43 +- tests/test_config.py | 189 +-- tests/test_embedded_commit.py | 16 +- tests/test_inputs.py | 29 +- tests/test_logger.py | 68 +- tests/test_outputs.py | 16 +- tests/test_regression.py | 25 +- tests/test_sampling_params.py | 69 +- tests/test_scalartype.py | 33 +- tests/test_seed_behavior.py | 48 +- tests/test_sequence.py | 20 +- tests/test_sharded_state_loader.py | 88 +- tests/test_triton_utils.py | 7 +- tests/test_utils.py | 314 +++-- tests/test_version.py | 3 +- tests/test_vllm_port.py | 13 +- tests/tokenization/test_cached_tokenizer.py | 19 +- tests/tokenization/test_detokenize.py | 281 ++-- tests/tokenization/test_get_eos.py | 7 +- tests/tokenization/test_mistral_tokenizer.py | 335 ++--- tests/tokenization/test_tokenizer.py | 2 +- tests/tokenization/test_tokenizer_group.py | 15 +- tests/tokenization/test_tokenizer_registry.py | 30 +- tests/tool_use/conftest.py | 30 +- ...est_chat_completion_request_validations.py | 104 +- tests/tool_use/test_chat_completions.py | 38 +- tests/tool_use/test_jamba_tool_parser.py | 248 ++-- tests/tool_use/test_kimi_k2_tool_parser.py | 110 +- tests/tool_use/test_minimax_tool_parser.py | 173 +-- tests/tool_use/test_parallel_tool_calls.py | 73 +- tests/tool_use/test_tool_calls.py | 
37 +- tests/tool_use/test_tool_choice_required.py | 347 +++-- tests/tool_use/test_xlam_tool_parser.py | 178 ++- tests/tool_use/utils.py | 410 +++--- tests/tools/test_config_validator.py | 23 +- tests/tpu/lora/test_lora.py | 73 +- tests/tpu/test_compilation.py | 29 +- tests/tpu/test_custom_dispatcher.py | 31 +- tests/tpu/test_moe_pallas.py | 9 +- tests/tpu/test_quantization_accuracy.py | 14 +- tests/tracing/test_tracing.py | 162 ++- tests/utils.py | 564 ++++---- tests/v1/attention/test_attention_backends.py | 269 ++-- tests/v1/attention/utils.py | 133 +- tests/v1/core/test_async_scheduler.py | 55 +- tests/v1/core/test_kv_cache_utils.py | 385 +++--- tests/v1/core/test_prefix_caching.py | 632 +++++---- tests/v1/core/test_scheduler.py | 521 ++++---- tests/v1/core/test_scheduler_e2e.py | 16 +- tests/v1/core/test_specialized_manager.py | 95 +- tests/v1/core/utils.py | 65 +- tests/v1/e2e/test_cascade_attention.py | 3 +- .../v1/e2e/test_correctness_sliding_window.py | 48 +- tests/v1/e2e/test_spec_decode.py | 52 +- tests/v1/engine/conftest.py | 44 +- tests/v1/engine/test_async_llm.py | 132 +- tests/v1/engine/test_engine_args.py | 12 +- tests/v1/engine/test_engine_core.py | 55 +- tests/v1/engine/test_engine_core_client.py | 173 ++- .../v1/engine/test_fast_incdec_prefix_err.py | 145 +- tests/v1/engine/test_llm_engine.py | 52 +- tests/v1/engine/test_output_processor.py | 504 +++---- tests/v1/engine/utils.py | 73 +- tests/v1/entrypoints/conftest.py | 114 +- .../llm/test_struct_output_generate.py | 345 ++--- .../entrypoints/openai/responses/conftest.py | 1 - .../openai/responses/test_basic.py | 43 +- .../openai/responses/test_image.py | 117 +- .../openai/responses/test_stateful.py | 18 +- .../responses/test_structured_output.py | 24 +- .../openai/test_chat_completion.py | 73 +- .../v1/entrypoints/openai/test_completion.py | 377 +++--- .../openai/test_multi_api_servers.py | 89 +- tests/v1/executor/test_multiproc_executor.py | 77 +- .../nixl_integration/test_accuracy.py | 33 +- .../nixl_integration/test_edge_cases.py | 39 +- .../nixl_integration/toy_proxy_server.py | 166 ++- .../kv_connector/unit/test_multi_connector.py | 114 +- .../kv_connector/unit/test_nixl_connector.py | 230 ++-- .../unit/test_remote_decode_lifecycle.py | 45 +- .../unit/test_remote_prefill_lifecycle.py | 136 +- tests/v1/kv_connector/unit/utils.py | 116 +- tests/v1/metrics/test_ray_metrics.py | 8 +- tests/v1/sample/test_logits_processors.py | 392 +++--- tests/v1/sample/test_logprobs.py | 184 +-- tests/v1/sample/test_logprobs_e2e.py | 30 +- tests/v1/sample/test_rejection_sampler.py | 250 ++-- tests/v1/sample/test_sampler.py | 178 +-- tests/v1/sample/test_sampling_params_e2e.py | 22 +- tests/v1/sample/test_topk_topp_sampler.py | 43 +- tests/v1/sample/utils.py | 42 +- tests/v1/shutdown/test_delete.py | 50 +- tests/v1/shutdown/test_forward_error.py | 54 +- tests/v1/shutdown/test_processor_error.py | 19 +- tests/v1/shutdown/test_startup_error.py | 47 +- tests/v1/spec_decode/test_eagle.py | 160 +-- tests/v1/spec_decode/test_ngram.py | 104 +- tests/v1/structured_output/test_utils.py | 153 +-- tests/v1/test_async_llm_dp.py | 81 +- tests/v1/test_external_lb_dp.py | 149 ++- tests/v1/test_metrics_reader.py | 44 +- tests/v1/test_oracle.py | 12 +- tests/v1/test_serial_utils.py | 80 +- tests/v1/test_utils.py | 54 +- tests/v1/tpu/test_basic.py | 90 +- tests/v1/tpu/test_kv_cache_update_kernel.py | 85 +- tests/v1/tpu/test_mha_attn.py | 23 +- tests/v1/tpu/test_multimodal.py | 44 +- tests/v1/tpu/test_pallas.py | 20 +- tests/v1/tpu/test_perf.py | 68 +- 
tests/v1/tpu/test_sampler.py | 59 +- .../v1/tpu/test_spmd_model_weight_loading.py | 23 +- tests/v1/tpu/test_topk_topp_sampler.py | 85 +- tests/v1/tpu/test_tpu_qkv_linear.py | 13 +- tests/v1/tpu/worker/test_tpu_model_runner.py | 127 +- tests/v1/worker/test_gpu_input_batch.py | 147 ++- tests/v1/worker/test_gpu_model_runner.py | 207 +-- tests/vllm_test_utils/setup.py | 6 +- .../vllm_test_utils/vllm_test_utils/blame.py | 4 +- .../vllm_test_utils/monitor.py | 27 +- tests/weight_loading/test_weight_loading.py | 32 +- tests/worker/conftest.py | 2 +- .../test_encoder_decoder_model_runner.py | 163 ++- tests/worker/test_model_input.py | 105 +- tests/worker/test_model_runner.py | 121 +- tests/worker/test_profile.py | 19 +- tests/worker/test_swap.py | 15 +- vllm/benchmarks/datasets.py | 34 +- vllm/benchmarks/serve.py | 30 +- 594 files changed, 33881 insertions(+), 28948 deletions(-) diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py index ec6b20f5e04b..57d1fe4256cb 100644 --- a/tests/async_engine/api_server_async_engine.py +++ b/tests/async_engine/api_server_async_engine.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """vllm.entrypoints.api_server with some extra logging for testing.""" + from collections.abc import Iterable from typing import Any @@ -17,7 +18,6 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._num_aborts = 0 @@ -47,8 +47,10 @@ def stats() -> Response: engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngineWithStats.from_engine_args(engine_args) vllm.entrypoints.api_server.engine = engine - uvicorn.run(app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE) + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ) diff --git a/tests/async_engine/conftest.py b/tests/async_engine/conftest.py index 375b248ebeda..a6a8b33e19d3 100644 --- a/tests/async_engine/conftest.py +++ b/tests/async_engine/conftest.py @@ -9,4 +9,4 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. 
""" - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 76c94bdf80ca..f6c35a118e66 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -13,13 +13,15 @@ def _query_server(prompt: str, max_tokens: int = 5) -> dict: - response = requests.post("http://localhost:8000/generate", - json={ - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": 0, - "ignore_eos": True - }) + response = requests.post( + "http://localhost:8000/generate", + json={ + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0, + "ignore_eos": True, + }, + ) response.raise_for_status() return response.json() @@ -30,8 +32,9 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture def api_server(distributed_executor_backend: str): - script_path = Path(__file__).parent.joinpath( - "api_server_async_engine.py").absolute() + script_path = ( + Path(__file__).parent.joinpath("api_server_async_engine.py").absolute() + ) commands = [ sys.executable, "-u", @@ -80,8 +83,9 @@ def test_api_server(api_server, distributed_executor_backend: str): for result in pool.map(_query_server, prompts): assert result - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] + num_aborted_requests = requests.get("http://localhost:8000/stats").json()[ + "num_aborted_requests" + ] assert num_aborted_requests == 0 # Try with 100 prompts @@ -101,8 +105,9 @@ def test_api_server(api_server, distributed_executor_backend: str): # give it some times to update the stats time.sleep(1) - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] + num_aborted_requests = requests.get("http://localhost:8000/stats").json()[ + "num_aborted_requests" + ] assert num_aborted_requests > 0 # check that server still runs after cancellations diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 0eb7a6eb52aa..c1ed64abd6e7 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -36,7 +36,6 @@ class MockModelConfig: class MockEngine: - def __init__(self): self.step_calls = 0 self.add_request_calls = 0 @@ -49,8 +48,7 @@ def __init__(self): async def step_async(self, virtual_engine): # PP size is 1, ignore virtual engine self.step_calls += 1 - return [RequestOutput( - request_id=self.request_id)] if self.request_id else [] + return [RequestOutput(request_id=self.request_id)] if self.request_id else [] async def process_model_inputs_async(self, *args, **kwargs): pass @@ -67,7 +65,7 @@ def stop_generating(self): def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 - print(f'Request calls: {self.add_request_calls}') + print(f"Request calls: {self.add_request_calls}") async def add_request_async(self, **kwargs): self.add_request_calls += 1 @@ -142,9 +140,12 @@ def start_engine(): print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", - enforce_eager=True, - num_scheduler_steps=num_scheduler_steps)) + AsyncEngineArgs( + model="facebook/opt-125m", + enforce_eager=True, + num_scheduler_steps=num_scheduler_steps, + ) + ) def uid() -> str: @@ -157,8 +158,9 @@ async def async_engine(): # scoped fixture and monkeypatch is function scoped. 
previous_value = os.getenv("VLLM_USE_V1", None) os.environ["VLLM_USE_V1"] = "0" - engine = await asyncio.get_event_loop().run_in_executor(executor=None, - func=start_engine) + engine = await asyncio.get_event_loop().run_in_executor( + executor=None, func=start_engine + ) try: yield engine finally: @@ -182,7 +184,6 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") @pytest.mark.parametrize("stop", [None, ["a stop string"]]) async def test_asyncio_run(async_engine, stop): - scheduler_config = await async_engine.get_scheduler_config() num_scheduler_steps = scheduler_config.num_scheduler_steps @@ -196,9 +197,9 @@ async def run(prompt: str): output_count = 0 final_output = None - async for output in async_engine.generate(prompt, - sampling_params, - request_id=uid()): + async for output in async_engine.generate( + prompt, sampling_params, request_id=uid() + ): output_count += 1 final_output = output return final_output, output_count @@ -247,18 +248,19 @@ async def run(prompt: str, kind: RequestOutputKind): output_count = 0 final_output = None - async for output in async_engine.generate(prompt, - params, - request_id=uid()): + async for output in async_engine.generate(prompt, params, request_id=uid()): output_count += 1 final_output = output assert final_output is not None assert final_output.finished - return (final_output.prompt_token_ids, - final_output.outputs[0].token_ids, - final_output.outputs[0].text, output_count) + return ( + final_output.prompt_token_ids, + final_output.outputs[0].token_ids, + final_output.outputs[0].text, + output_count, + ) async def run_deltas(prompt: str): params = copy(sampling_params) @@ -269,9 +271,7 @@ async def run_deltas(prompt: str): output_text = "" output_count = 0 final_output = None - async for output in async_engine.generate(prompt, - params, - request_id=uid()): + async for output in async_engine.generate(prompt, params, request_id=uid()): token_ids = output.outputs[0].token_ids text = output.outputs[0].text final_output = output @@ -298,7 +298,8 @@ async def run_deltas(prompt: str): results = await asyncio.gather( run("common input prompt", RequestOutputKind.CUMULATIVE), run("common input prompt", RequestOutputKind.FINAL_ONLY), - run_deltas("common input prompt")) + run_deltas("common input prompt"), + ) # Make sure outputs are the same prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) @@ -342,9 +343,9 @@ async def test_cancellation(async_engine, stop): i = 0 with pytest.raises(CancelledError): - async for output in async_engine.generate("test2", - sampling_params, - request_id=request_id): + async for output in async_engine.generate( + "test2", sampling_params, request_id=request_id + ): assert not output.finished i += 1 if i == stop_at: @@ -402,8 +403,7 @@ async def test_invalid_argument(async_engine): # Targeting specific DP rank only supported in v1 multi-instance DP with pytest.raises(ValueError): - async for _ in async_engine.generate("test", - sampling_params, - request_id=uid(), - data_parallel_rank=0): + async for _ in async_engine.generate( + "test", sampling_params, request_id=uid(), data_parallel_rank=0 + ): pass diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 1851eeeda790..784d6dbb796d 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -60,7 +60,8 @@ async def test_request_tracker(): stream_5 = tracker.add_request("5") assert 
tracker.new_requests_event.is_set() tracker.process_request_output( - RequestOutput("2", "output", [], [], [], finished=True)) + RequestOutput("2", "output", [], [], [], finished=True) + ) await tracker.wait_for_new_requests() new, aborted = tracker.get_new_and_aborted_requests() assert not tracker.new_requests_event.is_set() diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 2e103019f7af..c75defeda4da 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,7 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ + import os import weakref from unittest.mock import Mock @@ -46,16 +47,21 @@ def test_vllm_gc_ed(): def _fix_prompt_embed_outputs( - vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, - example_prompts: list[str]) -> list[tuple[list[int], str]]: + vllm_outputs: list[tuple[list[int], str]], + hf_model: HfRunner, + example_prompts: list[str], +) -> list[tuple[list[int], str]]: fixed_vllm_outputs = [] for vllm_output, hf_input, prompt in zip( - vllm_outputs, hf_model.get_inputs(example_prompts), - example_prompts): + vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts + ): hf_input_ids = hf_input["input_ids"].tolist()[0] fixed_vllm_outputs.append( - (hf_input_ids + vllm_output[0][len(hf_input_ids):], - prompt + vllm_output[1])) + ( + hf_input_ids + vllm_output[0][len(hf_input_ids) :], + prompt + vllm_output[1], + ) + ) return fixed_vllm_outputs @@ -73,18 +79,14 @@ def test_models( enforce_eager: bool, enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: + if enable_prompt_embeds and envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend in ("XFORMERS", - "FLASHINFER") and model == "google/gemma-2-2b-it": - pytest.skip( - f"{backend} does not support gemma2 with full context length.") + if backend in ("XFORMERS", "FLASHINFER") and model == "google/gemma-2-2b-it": + pytest.skip(f"{backend} does not support gemma2 with full context length.") with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", backend) @@ -92,30 +94,33 @@ def test_models( # 5042 tokens for gemma2 # gemma2 has alternating sliding window size of 4096 # we need a prompt with more than 4096 tokens to test the sliding window - prompt = "The following numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) - - with VllmRunner(model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7) as vllm_model: + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) + + with VllmRunner( + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + ) as vllm_model: if enable_prompt_embeds: - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, 
max_tokens) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) + vllm_outputs, hf_model, example_prompts + ) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -127,23 +132,20 @@ def test_models( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, " - "test_suite, extra_env", [ + "model, distributed_executor_backend, attention_backend, test_suite, extra_env", + [ ("distilbert/distilgpt2", "ray", "", "L4", {}), ("distilbert/distilgpt2", "mp", "", "L4", {}), - ("distilbert/distilgpt2", "ray", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), - ("distilbert/distilgpt2", "mp", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), + ("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}), ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}), ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}), - ]) + ], +) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( monkeypatch: pytest.MonkeyPatch, @@ -157,20 +159,21 @@ def test_models_distributed( extra_env: dict[str, str], enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: + if enable_prompt_embeds and envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") with monkeypatch.context() as monkeypatch_context: - if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + if ( + model == "meta-llama/Llama-3.2-1B-Instruct" + and distributed_executor_backend == "ray" + and attention_backend == "" + and test_suite == "L4" + ): # noqa if enable_prompt_embeds: - pytest.skip( - "enable_prompt_embeds does not work with ray compiled dag." - ) + pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -192,30 +195,26 @@ def test_models_distributed( # will hurt multiprocessing backend with fork method # (the default method). 
with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7, + model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, ) as vllm_model: if enable_prompt_embeds: with hf_runner(model, dtype=dtype) as hf_model: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, max_tokens) + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs, hf_model, example_prompts + ) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -226,27 +225,23 @@ def test_models_distributed( def test_failed_model_execution(vllm_runner, monkeypatch) -> None: - from vllm.envs import VLLM_USE_V1 if not VLLM_USE_V1: pytest.skip("Skipping V0 test, dump input not supported") # Needed to mock an error in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model: if isinstance(vllm_model.model.llm_engine, LLMEngineV1): v1_test_failed_model_execution(vllm_model) def v1_test_failed_model_execution(vllm_model): - engine = vllm_model.model.llm_engine - mocked_execute_model = Mock( - side_effect=RuntimeError("Mocked Critical Error")) - engine.engine_core.engine_core.model_executor.execute_model =\ - mocked_execute_model + mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error")) + engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model with pytest.raises(RuntimeError) as exc_info: prompts = [ diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 4816b76996fc..597155497c6e 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -37,7 +37,7 @@ def use_v0_only(monkeypatch: pytest.MonkeyPatch): all tests in the file. """ with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') + m.setenv("VLLM_USE_V1", "0") yield @@ -49,13 +49,18 @@ def use_v0_only(monkeypatch: pytest.MonkeyPatch): # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) +@pytest.mark.parametrize( + "attention_backend", + [ + pytest.param( + "FLASHINFER", + marks=pytest.mark.skipif( + current_platform.is_rocm(), reason="FLASHINFER isn't supported on ROCm" + ), + ), + "FLASH_ATTN", + ], +) def test_models( hf_runner: HfRunner, vllm_runner: VllmRunner, @@ -83,16 +88,15 @@ def test_models( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=True, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -105,13 +109,18 @@ def test_models( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) +@pytest.mark.parametrize( + "attention_backend", + [ + pytest.param( + "FLASHINFER", + marks=pytest.mark.skipif( + current_platform.is_rocm(), reason="FLASHINFER isn't supported on ROCm" + ), + ), + "FLASH_ATTN", + ], +) def test_models_distributed( hf_runner: HfRunner, vllm_runner: VllmRunner, @@ -123,8 +132,10 @@ def test_models_distributed( ) -> None: with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - if (model == "meta-llama/Llama-3.2-1B-Instruct" - and distributed_executor_backend == "ray"): + if ( + model == "meta-llama/Llama-3.2-1B-Instruct" + and distributed_executor_backend == "ray" + ): # test Ray Compiled Graph m.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") m.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -146,13 +157,13 @@ def test_models_distributed( # fork method (the default method). 
with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy( example_prompts, @@ -172,8 +183,8 @@ def test_models_distributed( @pytest.mark.parametrize( "kv_cache_dtype,model", - [("fp8_e4m3", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]) + [("fp8_e4m3", "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")], +) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) @@ -184,8 +195,9 @@ def test_models_distributed( # Due to low-precision numerical divergence, this test is too sensitive to # the async postprocessor @pytest.mark.parametrize("disable_async_output_proc", [True]) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="machete_prepack_B isn't supported on ROCm") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="machete_prepack_B isn't supported on ROCm" +) def test_models_with_fp8_kv_cache( vllm_runner: VllmRunner, example_prompts, @@ -208,28 +220,30 @@ def test_models_with_fp8_kv_cache( max_num_batched_tokens = chunked_prefill_token_size with vllm_runner( - model, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, + model, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, ) as vllm_model: no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) with vllm_runner( - model, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, + model, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=True, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, ) as vllm_model: chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) check_logprobs_close( outputs_0_lst=no_chunked_prefill_outputs, @@ -272,14 +286,14 @@ def test_with_prefix_caching( outputs = {} # type: ignore for enable in (True, False): with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - enable_prefix_caching=enable, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=True, + 
enable_prefix_caching=enable, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, ) as vllm_model: outputs[enable] = [] for prompt in full_prompts: diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 28bfe9e7c802..3c1e01d072b9 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -5,5 +5,6 @@ def test_cpu_offload(): - compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [], - ["--cpu-offload-gb", "1"]) + compare_two_settings( + "meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"] + ) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 34f9389c82a9..f05dd9244b31 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -23,13 +23,13 @@ def test_python_error(): tensors = [] with allocator.use_memory_pool(): # allocate 70% of the total memory - x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") tensors.append(x) # release the memory allocator.sleep() # allocate more memory than the total memory - y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") tensors.append(y) with pytest.raises(RuntimeError): # when the allocator is woken up, it should raise an error @@ -41,17 +41,17 @@ def test_python_error(): def test_basic_cumem(): # some tensors from default memory pool shape = (1024, 1024) - x = torch.empty(shape, device='cuda') + x = torch.empty(shape, device="cuda") x.zero_() # some tensors from custom memory pool allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(): # custom memory pool - y = torch.empty(shape, device='cuda') + y = torch.empty(shape, device="cuda") y.zero_() y += 1 - z = torch.empty(shape, device='cuda') + z = torch.empty(shape, device="cuda") z.zero_() z += 2 @@ -74,16 +74,16 @@ def test_basic_cumem(): def test_cumem_with_cudagraph(): allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(): - weight = torch.eye(1024, device='cuda') + weight = torch.eye(1024, device="cuda") with allocator.use_memory_pool(tag="discard"): - cache = torch.empty(1024, 1024, device='cuda') + cache = torch.empty(1024, 1024, device="cuda") def model(x): out = x @ weight - cache[:out.size(0)].copy_(out) + cache[: out.size(0)].copy_(out) return out + 1 - x = torch.empty(128, 1024, device='cuda') + x = torch.empty(128, 1024, device="cuda") # warmup model(x) @@ -109,7 +109,7 @@ def model(x): model_graph.replay() # cache content is as expected - assert torch.allclose(x, cache[:x.size(0)]) + assert torch.allclose(x, cache[: x.size(0)]) # output content is as expected assert torch.allclose(y, x + 1) @@ -123,7 +123,8 @@ def model(x): ("meta-llama/Llama-3.2-1B", True), # sleep mode with pytorch checkpoint ("facebook/opt-125m", False), - ]) + ], +) def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 341a39a42b85..cc93e6f12e12 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -7,13 +7,13 @@ Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest tests/basic_correctness/test_preemption.py`. 
""" + import pytest from prometheus_client import REGISTRY import vllm.envs as envs from vllm import SamplingParams -from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, - ENABLE_ARTIFICIAL_PREEMPT) +from vllm.core.scheduler import ARTIFICIAL_PREEMPTION_MAX_CNT, ENABLE_ARTIFICIAL_PREEMPT from ..models.utils import check_outputs_equal @@ -28,7 +28,7 @@ def use_v0_only(monkeypatch): We should enable this for V1, but VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT, so use VLLM_USE_V1=0 for all tests in the file. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") @pytest.fixture(scope="module", autouse=True) @@ -36,7 +36,8 @@ def check_settings(): assert ENABLE_ARTIFICIAL_PREEMPT is True, ( "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1." "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " - "pytest tests/basic_correctness/test_preemption.py`") + "pytest tests/basic_correctness/test_preemption.py`" + ) @pytest.fixture @@ -72,25 +73,29 @@ def test_chunked_prefill_recompute( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, - max_num_seqs=max_num_seqs, - distributed_executor_backend=distributed_executor_backend, - disable_log_stats=False, + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + distributed_executor_backend=distributed_executor_backend, + disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) + assert ( + vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT + ) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] vllm_output_ids, vllm_output_str = vllm_outputs[i] assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}" + ) assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}" + ) @pytest.mark.parametrize("model", MODELS) @@ -112,16 +117,19 @@ def test_preemption( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - distributed_executor_backend=distributed_executor_backend, + model, + dtype=dtype, + disable_log_stats=False, + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) + assert ( + vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT + ) + total_preemption = vllm_model.model.llm_engine.scheduler[ + 0 + ].num_cumulative_preemption check_outputs_equal( outputs_0_lst=hf_outputs, @@ -130,8 +138,10 @@ def test_preemption( name_1="vllm", ) - assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " - "is not enough KV cache space." 
in caplog_vllm.text) + assert ( + "is preempted by PreemptionMode.RECOMPUTE mode because there " + "is not enough KV cache space." in caplog_vllm.text + ) # Ensure the count bucket of request-level histogram metrics matches # the number of requests as a simple sanity check to ensure metrics are # generated @@ -162,25 +172,26 @@ def test_preemption_infeasible( prefill_blocks = 2 decode_blocks = max_tokens // BLOCK_SIZE with vllm_runner( - model, - dtype=dtype, - block_size=BLOCK_SIZE, - # Not enough gpu blocks to complete a single sequence. - # preemption should happen, and the sequence should be - # ignored instead of hanging forever. - num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, - max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), - distributed_executor_backend=distributed_executor_backend, + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. + num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) req_outputs = vllm_model.model.generate( example_prompts, sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) + assert ( + vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT + ) # Verify the request is ignored and not hang. for req_output in req_outputs: diff --git a/tests/benchmarks/test_latency_cli.py b/tests/benchmarks/test_latency_cli.py index 2279c846e01c..54075a3a15e6 100644 --- a/tests/benchmarks/test_latency_cli.py +++ b/tests/benchmarks/test_latency_cli.py @@ -10,8 +10,18 @@ @pytest.mark.benchmark def test_bench_latency(): command = [ - "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", - "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "latency", + "--model", + MODEL_NAME, + "--input-len", + "32", + "--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index bfcf274727e2..48d524ceebd0 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -11,9 +11,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" - ] + args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py index b61e51db4fbe..a579b59e8af4 100644 --- a/tests/benchmarks/test_throughput_cli.py +++ b/tests/benchmarks/test_throughput_cli.py @@ -10,8 +10,18 @@ @pytest.mark.benchmark def test_bench_throughput(): command = [ - "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", - "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "throughput", + "--model", + MODEL_NAME, + "--input-len", + "32", + 
"--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/build_cython.py b/tests/build_cython.py index 444434e8f0a7..5a968651d8c2 100644 --- a/tests/build_cython.py +++ b/tests/build_cython.py @@ -28,12 +28,13 @@ "vllm/utils/__init__.py", ] -setup(ext_modules=cythonize(infiles, - annotate=False, - force=True, - compiler_directives={ - 'language_level': "3", - 'infer_types': True - })) +setup( + ext_modules=cythonize( + infiles, + annotate=False, + force=True, + compiler_directives={"language_level": "3", "infer_types": True}, + ) +) # example usage: python3 build_cython.py build_ext --inplace diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ace4d25534cd..0a362de08df3 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -25,20 +25,18 @@ class TestBackend: Inductor config is default-initialized from VllmConfig.CompilationConfig. """ - def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], - None]]): + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) compile_config = get_current_vllm_config().compilation_config self.inductor_config = compile_config.inductor_compile_config - self.inductor_config['force_disable_caches'] = True - self.inductor_config['post_grad_custom_post_pass'] = self.post_pass + self.inductor_config["force_disable_caches"] = True + self.inductor_config["post_grad_custom_post_pass"] = self.post_pass def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, - example_inputs, - config_patches=self.inductor_config) + + return compile_fx(graph, example_inputs, config_patches=self.inductor_config) def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) @@ -56,12 +54,11 @@ def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" assert num_pre > num_post, f"All nodes remain for op {op.name()}" if fully_replaced: - assert num_post == 0, \ - f"Unexpected op {op.name()} in post-pass graph" + assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" def check_after_ops(self, ops: Sequence[OpOverload]): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass))) assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" - assert num_post > 0, f"Op {op.name()} not found in post-pass graph" \ No newline at end of file + assert num_post > 0, f"Op {op.name()} not found in post-pass graph" diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index efe9c843f144..99dd61254a75 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -36,10 +36,7 @@ def temporary_environ(env_vars): def llm_pair(request): model = request.param - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }): + with temporary_environ({"VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3"}): full = LLM( model=model, gpu_memory_utilization=0.45, @@ -71,11 +68,14 @@ def llm_pair(request): [ # Model names for the llm_pair fixture "deepseek-ai/DeepSeek-V2-Lite", - "Qwen/Qwen2-1.5B-Instruct" + "Qwen/Qwen2-1.5B-Instruct", ], - 
indirect=True) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") + indirect=True, +) +@pytest.mark.skipif( + current_platform.get_device_capability() != (9, 0), + reason="Only Hopper GPUs support FA3 and FlashMLA", +) class TestFullCUDAGraph: """ Use a class such that an llm pair is constructed once for all @@ -85,20 +85,22 @@ class TestFullCUDAGraph: meaning there would be multiple LLM instances hogging memory simultaneously. """ - @pytest.mark.parametrize(("batch_size", "max_tokens"), [ - (1, 10), - (7, 10), - (16, 10), - (25, 10), - (32, 10), - (45, 10), - (64, 10), - (123, 10), - (8, 5), - (8, 30), - ]) - def test_full_cudagraph(self, batch_size, max_tokens, - llm_pair: tuple[LLM, LLM]): + @pytest.mark.parametrize( + ("batch_size", "max_tokens"), + [ + (1, 10), + (7, 10), + (16, 10), + (25, 10), + (32, 10), + (45, 10), + (64, 10), + (123, 10), + (8, 5), + (8, 30), + ], + ) + def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]): """ Test various batch sizes and max_tokens to ensure that the full cudagraph compilation works for padded cases too. @@ -107,16 +109,15 @@ def test_full_cudagraph(self, batch_size, max_tokens, piecewise_llm, full_cudagraph_llm = llm_pair prompts = ["Hello, my name is"] * batch_size - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - top_p=0.95) + sampling_params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, top_p=0.95 + ) piecewise_responses = piecewise_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params) # Check that all responses are the same - for piecewise_res, full_res in zip(piecewise_responses, - full_responses): + for piecewise_res, full_res in zip(piecewise_responses, full_responses): assert piecewise_res.outputs[0].text == full_res.outputs[0].text @@ -126,33 +127,44 @@ def test_full_cudagraph(self, batch_size, max_tokens, ("Qwen/Qwen2-1.5B-Instruct", True), # MLA does not support capturing CUDA Graphs with size > max_num_seqs ("deepseek-ai/DeepSeek-V2-Lite", False), - ]) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") + ], +) +@pytest.mark.skipif( + current_platform.get_device_capability() != (9, 0), + reason="Only Hopper GPUs support FA3 and FlashMLA", +) def test_lower_max_num_seqs(model, supported): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }), ExitStack() as stack: + with ( + temporary_environ({"VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3"}), + ExitStack() as stack, + ): if not supported: stack.enter_context(pytest.raises(RuntimeError)) - llm = LLM(model=model, - max_num_seqs=256, - trust_remote_code=True, - max_model_len=1024, - compilation_config=CompilationConfig( - full_cuda_graph=True, - cudagraph_capture_sizes=[64, 256, 512])) + llm = LLM( + model=model, + max_num_seqs=256, + trust_remote_code=True, + max_model_len=1024, + compilation_config=CompilationConfig( + full_cuda_graph=True, cudagraph_capture_sizes=[64, 256, 512] + ), + ) llm.generate(["Hello, my name is"] * 10) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": - "2" #FA2 not supported with full_cuda_graph - }), pytest.raises(RuntimeError): - LLM(model="Qwen/Qwen2-1.5B-Instruct", - 
compilation_config=CompilationConfig(full_cuda_graph=True)) + with ( + temporary_environ( + { + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "2", # FA2 not supported with full_cuda_graph + } + ), + pytest.raises(RuntimeError), + ): + LLM( + model="Qwen/Qwen2-1.5B-Instruct", + compilation_config=CompilationConfig(full_cuda_graph=True), + ) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 06ac3527e1fb..ee67d6696f70 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,6 +4,7 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. """ + import pytest import torch from torch import nn @@ -11,8 +12,12 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op @@ -23,8 +28,9 @@ silly_lib = Library("silly", "FRAGMENT") # noqa -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: global global_counter global_counter += 1 print(f"{global_counter=}") @@ -32,8 +38,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out[0] += 1 -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: return @@ -48,12 +55,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @support_torch_compile class SillyModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -81,28 +83,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_simple_piecewise_compile(use_inductor): assert VLLM_USE_V1 - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - use_inductor=use_inductor, - splitting_ops=["silly.attention"], - cudagraph_copy_inputs=True, - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + use_inductor=use_inductor, + splitting_ops=["silly.attention"], + cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], + ) + ) with set_current_vllm_config(vllm_config): - model = SillyModel(vllm_config=vllm_config, prefix='') + model = SillyModel(vllm_config=vllm_config, prefix="") inputs = torch.randn(100).cuda() - with compilation_counter.expect( + with ( + compilation_counter.expect( num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=5, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=3, # 1 + num_layers num_backend_compilations=3, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - 
), set_forward_context({}, vllm_config=vllm_config): - + num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ), + set_forward_context({}, vllm_config=vllm_config), + ): model(inputs) model(torch.randn(2).cuda()) @@ -113,4 +118,4 @@ def test_simple_piecewise_compile(use_inductor): global_counter = 0 output = model(input) assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index b7ed8353b3ce..738cc45da6e0 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are initialized randomly with a fixed seed. """ + from dataclasses import dataclass from typing import Any, Optional @@ -18,8 +19,12 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op @@ -27,15 +32,17 @@ silly_lib = Library("silly", "FRAGMENT") # noqa -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: out.copy_(q) out += k out += v -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: return @@ -66,15 +73,14 @@ def compute_hash(self) -> str: factors.append((k, v)) factors.sort() import hashlib - return hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + + return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() def __post_init__(self): assert self.mlp_size >= self.hidden_size class LlamaMLP(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.gate_up_projection = nn.Linear( @@ -89,31 +95,31 @@ def __init__(self, config: LlamaConfig) -> None: ) if config.tractable_init: - nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size]) - nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:]) + nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size]) + nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :]) nn.init.eye_(self.down_projection.weight.data) else: - nn.init.xavier_normal_(self.gate_up_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.down_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.gate_up_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.down_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward(self, x): # for tractable_init and positive input, this is # essentially an elementwise-square x = self.gate_up_projection(x) - x = x[:, 
:x.size(1) // 2] * torch.nn.functional.relu( - x[:, x.size(1) // 2:]) + x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :]) x = self.down_projection(x) return x class LlamaAttention(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.qkv_projection = nn.Linear( @@ -129,21 +135,25 @@ def __init__(self, config: LlamaConfig) -> None: ) if config.tractable_init: - nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size]) - nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 * - config.hidden_size]) - nn.init.eye_(self.qkv_projection.weight.data[2 * - config.hidden_size:]) + nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size]) + nn.init.eye_( + self.qkv_projection.weight.data[ + config.hidden_size : 2 * config.hidden_size + ] + ) + nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :]) nn.init.eye_(self.output_projection.weight.data) else: - nn.init.xavier_normal_(self.qkv_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.output_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.qkv_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.output_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward( self, @@ -167,7 +177,6 @@ def forward( class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.self_attention = LlamaAttention(config) @@ -187,7 +196,7 @@ def forward( - if residual is not None, the outputs are: - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3 - hidden_states = (residual + 1) ** 2 - """ # noqa + """ # noqa if residual is None: residual = hidden_states hidden_states = hidden_states + 1 @@ -196,8 +205,9 @@ def forward( residual = hidden_states hidden_states = hidden_states + 1 - hidden_states = self.self_attention(positions=positions, - hidden_states=hidden_states) + hidden_states = self.self_attention( + positions=positions, hidden_states=hidden_states + ) hidden_states = hidden_states + residual residual = hidden_states @@ -209,20 +219,22 @@ def forward( @support_torch_compile class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - config: LlamaConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + vllm_config: VllmConfig, + config: LlamaConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.embedding_tokens = nn.Embedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, ) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_layers)]) + [LlamaDecoderLayer(config) for _ in range(config.num_layers)] + ) # this is the initial value of the hidden states self.embedding_tokens.weight.data.fill_(config.init_value) @@ -239,34 +251,39 @@ def forward( return hidden_states -def tractable_computation(input_ids: torch.Tensor, - positions: torch.Tensor, - config: LlamaConfig, - init_value: float = 1.0) -> torch.Tensor: - hidden_states = torch.ones(input_ids.size(0), - config.hidden_size, - device=input_ids.device, - dtype=input_ids.dtype) * init_value +def tractable_computation( + input_ids: 
torch.Tensor, + positions: torch.Tensor, + config: LlamaConfig, + init_value: float = 1.0, +) -> torch.Tensor: + hidden_states = ( + torch.ones( + input_ids.size(0), + config.hidden_size, + device=input_ids.device, + dtype=input_ids.dtype, + ) + * init_value + ) # first layer residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 # following layers for _ in range(config.num_layers - 1): hidden_states = hidden_states + residual residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 return hidden_states @torch.inference_mode -def run_model(llama_config, - use_compile: bool, - use_inductor: bool, - split_attn: bool = False) -> torch.Tensor: - +def run_model( + llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False +) -> torch.Tensor: if use_compile: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, @@ -278,18 +295,22 @@ def run_model(llama_config, compilation_config.splitting_ops = ["silly.attention"] else: compilation_config = CompilationConfig( - level=CompilationLevel.NO_COMPILATION, ) + level=CompilationLevel.NO_COMPILATION, + ) - vllm_config = VllmConfig(compilation_config=compilation_config, - additional_config=llama_config) + vllm_config = VllmConfig( + compilation_config=compilation_config, additional_config=llama_config + ) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda() + model = ( + LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + ) with set_forward_context({}, vllm_config=vllm_config): B = 16 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() positions = torch.arange(B).cuda() model(input_ids, positions) @@ -302,9 +323,9 @@ def run_model(llama_config, output = output.cpu() if llama_config.tractable_init: - expected_output = tractable_computation(input_ids[:2], - positions[:2], - llama_config).cpu() + expected_output = tractable_computation( + input_ids[:2], positions[:2], llama_config + ).cpu() assert torch.allclose(output, expected_output) else: @@ -315,27 +336,23 @@ def run_model(llama_config, def test_toy_llama(use_inductor: bool): # compare output with and without piecewise compilation - llama_config = LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=12) + llama_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12 + ) - tractable_config = LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=2, - tractable_init=True) + tractable_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True + ) outputs = [] with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): - outputs.append( - run_model(llama_config, use_inductor=False, use_compile=False)) + outputs.append(run_model(llama_config, use_inductor=False, use_compile=False)) run_model(tractable_config, use_inductor=False, use_compile=False) if use_inductor: @@ -344,41 +361,41 @@ def 
test_toy_llama(use_inductor: bool): kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=1, - num_piecewise_capturable_graphs_seen=1, - num_backend_compilations=1, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - **kwargs, + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=1, + num_piecewise_capturable_graphs_seen=1, + num_backend_compilations=1, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + **kwargs, ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True)) + run_model(llama_config, use_inductor=use_inductor, use_compile=True) + ) run_model(tractable_config, use_inductor=use_inductor, use_compile=True) with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=2 * llama_config.num_layers + - 1, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=1 + - llama_config.num_layers, # 1 + num_layers - num_backend_compilations=1 + - llama_config.num_layers, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=2 * - (1 + llama_config.num_layers - ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=1 + + llama_config.num_layers, # 1 + num_layers + num_backend_compilations=1 + + llama_config.num_layers, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2 + * ( + 1 + llama_config.num_layers + ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True)) - run_model(tractable_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True) + run_model( + llama_config, + use_inductor=use_inductor, + use_compile=True, + split_attn=True, + ) + ) + run_model( + tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True + ) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) @@ -389,17 +406,15 @@ def benchmark(): from triton.testing import do_bench # similar to llama 3.1-8B - llama_config = LlamaConfig(hidden_size=4096, - mlp_size=14336, - vocab_size=128 * 1024, - num_layers=32) + llama_config = LlamaConfig( + hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32 + ) # a tiny model to measure the overhead # of piecewise cudagraph - llama_config = LlamaConfig(hidden_size=40, - mlp_size=80, - vocab_size=128, - num_layers=2) + llama_config = LlamaConfig( + hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2 + ) cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] @@ -425,12 +440,15 @@ def benchmark(): vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda().to(torch.bfloat16) + model = ( + LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + .to(torch.bfloat16) + ) B = 256 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() 
positions = torch.arange(B).cuda().to(torch.bfloat16) graphs = {} @@ -452,21 +470,25 @@ def benchmark(): # and use it later, because it will look up the name `b` in the # enclosing scope, and the value of `b` will always be 256. # it is fine here, because we only use the lambda function once. - runtime = do_bench(lambda: graphs[b][0] # noqa - (input_ids[:b], positions[:b])) # noqa + runtime = do_bench( + lambda: graphs[b][0]( # noqa + input_ids[:b], positions[:b] + ) + ) # noqa piecewise_cudagraph_time[b] = runtime else: runtime = do_bench(lambda: graphs[b][0].replay()) # noqa - eager_runtime = do_bench( - lambda: model(input_ids[:b], positions[:b])) # noqa + eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa full_cudagraph_time[b] = runtime eager_time[b] = eager_runtime # print in tabular format print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") for b in cudagraph_sizes: - print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" - f"\t{piecewise_cudagraph_time[b]:.3f}") + print( + f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" + f"\t{piecewise_cudagraph_time[b]:.3f}" + ) if __name__ == "__main__": diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 62804e721e3d..78ac35a28ba6 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -8,18 +8,30 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AsyncTPPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter, +) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import (compare_two_settings, create_new_process_for_each_test, - multi_gpu_test) +from ..utils import ( + compare_two_settings, + create_new_process_for_each_test, + multi_gpu_test, +) from .backend import TestBackend prompts = [ @@ -31,20 +43,19 @@ class TestMMRSModel(torch.nn.Module): - def __init__(self, hidden_size=16): super().__init__() self.hidden_size = hidden_size - self.gate_proj = torch.nn.Parameter(torch.empty( - (self.hidden_size * 2, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False + ) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) def forward(self, hidden_states): """ Forward pass implementing the mm + reduce scatter in the FX graph - + """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) @@ -63,13 +74,12 @@ def ops_in_model_after(self): class TestAGMMModel(torch.nn.Module): - def __init__(self, hidden_size=16): super().__init__() self.hidden_size = hidden_size - self.weight = torch.nn.Parameter(torch.empty( - (hidden_size, hidden_size)), - requires_grad=False) + self.weight = torch.nn.Parameter( + torch.empty((hidden_size, hidden_size)), requires_grad=False + ) # Initialize weights 
torch.nn.init.normal_(self.weight, std=0.02) @@ -97,28 +107,33 @@ def ops_in_model_after(self): @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_async_tp_pass_replace( + test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype +): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + nprocs=nprocs, + ) run_torch_spawn(async_tp_pass_on_test_model, num_processes) -def async_tp_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def async_tp_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -126,13 +141,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -140,29 +157,34 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_async_tp=True, ), ) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_async_tp=True, + ), + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - task="auto", - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42, + ) async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) model = test_model_cls(hidden_size) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype, - requires_grad=False) + hidden_states = torch.randn( + (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False + ) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) @@ -210,12 +232,10 @@ def test_async_tp_pass_correctness( common_args.append("--enforce-eager") compilation_config = { - 'level': 3, - 'compile_sizes': [2, 4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_async_tp': async_tp_enabled - }, + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": {"enable_async_tp": async_tp_enabled}, } async_tp_env = tp_env = { @@ -240,9 +260,6 @@ def test_async_tp_pass_correctness( "mp", ] - compare_two_settings(model_id, - async_tp_args, - tp_args, - async_tp_env, - tp_env, - method="generate") + compare_two_settings( + model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate" + ) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 1ee9b234d9f4..35d7bccb8b63 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -62,8 +62,12 @@ class TestSetting: TestSetting( model="BAAI/bge-multilingual-gemma2", model_args=[ - "--task", "embed", "--dtype", "bfloat16", "--max-model-len", - "2048" + "--task", + "embed", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, @@ -92,7 +96,8 @@ class TestSetting: method="generate_with_image", fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting, @@ -108,23 +113,28 @@ def test_compile_correctness( method = test_setting.method fullgraph = test_setting.fullgraph if cuda_device_count_stateless() != pp_size * tp_size: - pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got " - f"{cuda_device_count_stateless()}") + pytest.skip( + f"Need exactly {pp_size}*{tp_size} CUDA gpus but got " + f"{cuda_device_count_stateless()}" + ) with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) final_args = [ - "--enforce-eager", *model_args, "-pp", - str(pp_size), "-tp", - str(tp_size) + "--enforce-eager", + *model_args, + "-pp", + str(pp_size), + "-tp", + str(tp_size), ] all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.PIECEWISE, + CompilationLevel.NO_COMPILATION, + CompilationLevel.PIECEWISE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) @@ -135,20 +145,20 @@ def test_compile_correctness( model, all_args, all_envs, - method=method if method != "generate" else "generate_close") + method=method if method != "generate" else "generate_close", + ) all_envs.clear() all_args.clear() for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, + CompilationLevel.NO_COMPILATION, + 
CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: # "DYNAMO_ONCE" will always use fullgraph - all_envs[-1][ - "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore + all_envs[-1]["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 0ba59f4b5a05..4266cff7a85f 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -9,11 +9,11 @@ def test_version(): - assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev') - assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev') + assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev") + assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") def test_use_cudagraphs_dynamic(monkeypatch): @@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch): vllm_config = VllmConfig() assert vllm_config.compilation_config.use_cudagraph - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") vllm_config = VllmConfig() assert not vllm_config.compilation_config.use_cudagraph @@ -34,19 +34,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): assert vllm.envs.VLLM_USE_V1 # spawn means that the counters are in the same process. 
- monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn") - monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val) compilation_config = { "use_cudagraph": False, # speed things up a bit } with ( - compilation_counter.expect(num_cache_entries_updated=0, - num_compiled_artifacts_saved=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_cache_entries_updated=0, num_compiled_artifacts_saved=0 + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -55,20 +59,23 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): assert vllm.envs.VLLM_USE_V1 # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = { "cudagraph_capture_sizes": [100], "use_cudagraph": enabled, } with ( - compilation_counter.expect( - num_graphs_seen=1, - num_gpu_runner_capture_triggers=1 if enabled else 0, - num_cudagraph_captured=13 if enabled else 0, - ), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_graphs_seen=1, + num_gpu_runner_capture_triggers=1 if enabled else 0, + num_cudagraph_captured=13 if enabled else 0, + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 72f962ed7484..3707cb196eeb 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -20,53 +20,67 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { - "dtype": torch.float16, - }), - ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { - "dtype": torch.float16, - }), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + { + "dtype": torch.float16, + }, + ), + ( + "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", + { + "dtype": torch.float16, + }, + ), ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: if is_quant_method_supported("aqlm"): - TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { - "quantization": "aqlm" - })) + TEST_MODELS.append( + ("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {"quantization": "aqlm"}) + ) # TODO: figure out why this fails. 
if False and is_quant_method_supported("gguf"): # noqa: SIM223 - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { - "quantization": "gguf" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"}) + ) if is_quant_method_supported("gptq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { - "quantization": "gptq" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"}) + ) if is_quant_method_supported("gptq_marlin"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { - "quantization": "gptq_marlin" - })) + TEST_MODELS.append( + ( + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + {"quantization": "gptq_marlin"}, + ) + ) if is_quant_method_supported("gptq_marlin_24"): - TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { - "quantization": "gptq_marlin_24" - })) + TEST_MODELS.append( + ( + "alexm-nm/tinyllama-24-marlin24-4bit-g128", + {"quantization": "gptq_marlin_24"}, + ) + ) if is_quant_method_supported("marlin"): TEST_MODELS.append( - ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { - "quantization": "marlin" - })) + ( + "robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", + {"quantization": "marlin"}, + ) + ) if not current_platform.is_rocm() and is_quant_method_supported("awq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { - "quantization": "AWQ" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"}) + ) if keywords is None: return TEST_MODELS @@ -102,22 +116,34 @@ def test_full_graph( "compilation_config, model_info", [ # additional compile sizes, only some of the models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - compile_sizes=[1, 2]), model) + ( + CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]), + model, + ) for model in models_list(all=False) - ] + [ + ] + + [ # RMSNorm + quant fusion, only 8-bit quant models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - pass_config=PassConfig(enable_fusion=True, - enable_noop=True)), model) + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), + model, + ) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) - ] + [ + ] + + [ # Test depyf integration works - (CompilationConfig(level=CompilationLevel.PIECEWISE, - debug_dump_path=tempfile.gettempdir()), - ("facebook/opt-125m", {})), - ]) + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir() + ), + ("facebook/opt-125m", {}), + ), + ], +) # only test some of the models @create_new_process_for_each_test() def test_custom_compile_config( @@ -129,8 +155,11 @@ def test_custom_compile_config( run_model(compilation_config, model, model_kwargs) -def run_model(compile_config: Union[int, CompilationConfig], model: str, - model_kwargs: dict[str, Any]): +def run_model( + compile_config: Union[int, CompilationConfig], + model: str, + model_kwargs: dict[str, Any], +): prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index aade29b99de7..1096d5744dbc 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,8 +8,13 @@ from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import 
ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import ( + FUSED_OPS, + FusionPass, + QuantKey, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, +) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig @@ -26,7 +31,7 @@ RMS_QUANT_OPS = { "static_fp8": [ torch.ops._C.rms_norm_static_fp8_quant.default, - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ], } @@ -43,25 +48,27 @@ @pytest.mark.parametrize( "model, quant_key", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", - kFp8DynamicTokenSym)]) + [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", kFp8DynamicTokenSym), + ], +) @pytest.mark.parametrize("do_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", - reason="Only test on CUDA") -def test_fix_functionalization(model: str, quant_key: QuantKey, - do_fusion: bool): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") +def test_fix_functionalization(model: str, quant_key: QuantKey, do_fusion: bool): torch.set_default_device("cuda") vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + ) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = FusionPass.instance(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass, act_quant_fusion_pass - ] if do_fusion else [noop_pass] + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass] if do_fusion else [noop_pass] + ) func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) @@ -76,14 +83,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, # 2 LLM instances. sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_func) + model_runner.model = torch.compile(orig_model, fullgraph=True, backend=backend_func) gen_func = llm.generate(prompts, sampling_params) - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_no_func) + model_runner.model = torch.compile( + orig_model, fullgraph=True, backend=backend_no_func + ) gen_no_func = llm.generate(prompts, sampling_params) @@ -92,19 +97,22 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, # and replaced by fused quantized ops in RMS_QUANT_OPS. 
- rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] - ] if do_fusion else [RMS_OP] - silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \ - quant_key == kFp8StaticTensorSym else [ - SILU_MUL_OP - ] + rms_ops = ( + [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]] + if do_fusion + else [RMS_OP] + ) + silu_mul_ops = ( + [SILU_MUL_QUANT_OP] + if do_fusion and quant_key == kFp8StaticTensorSym + else [SILU_MUL_OP] + ) ops = OPS_IN_MODEL + rms_ops + silu_mul_ops for op in ops: find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501 # make sure the ops were all de-functionalized found = dict() diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a3820e20fd8..399d3045cd87 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -6,14 +6,22 @@ import vllm.envs as envs import vllm.plugins -from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) +from vllm.compilation.fusion import ( + FUSED_OPS, + QUANT_OPS, + FusedRMSQuantKey, + FusionPass, + GroupShape, + QuantKey, +) from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) + CUTLASS_FP8_SUPPORTED, + Fp8LinearOp, + maybe_create_device_identity, +) from vllm.platforms import current_platform from .backend import TestBackend @@ -22,18 +30,23 @@ class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, eps: float, static: bool, - cutlass_fp8_enabled: bool, *args, **kwargs): + def __init__( + self, + hidden_size: int, + eps: float, + static: bool, + cutlass_fp8_enabled: bool, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + self.key = QuantKey( + dtype=FP8_DTYPE, static=static, group_shape=group_shape, symmetric=True + ) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -52,17 +65,15 @@ def forward(self, x): resid = torch.sqrt(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply(y, - self.w[0], - self.wscale[0], - input_scale=self.scale[0]) + x2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply(y2, - self.w[1], - self.wscale[1], - input_scale=self.scale[1]) + x3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -72,7 +83,7 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)] + 
FUSED_OPS[FusedRMSQuantKey(self.key, True)], ] @@ -81,22 +92,27 @@ def ops_in_model_after(self): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], - reason="Only test on CUDA and ROCm") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cutlass_fp8_enabled): +@pytest.mark.parametrize( + "cutlass_fp8_enabled", [True, False] if CUTLASS_FP8_SUPPORTED else [False] +) +@pytest.mark.skipif( + envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" +) +def test_fusion_rmsnorm_quant( + dtype, hidden_size, num_tokens, eps, static, cutlass_fp8_enabled +): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ) + ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 492e90f2a75f..26c8f4c70177 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -7,11 +7,19 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AllReduceFusionPass -from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, - ModelConfig, PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -21,7 +29,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -42,7 +49,6 @@ def ops_in_model_after(self): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -64,37 +70,45 @@ def ops_in_model_after(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model", - [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel]) + "test_model", [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel] +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - 
reason="Only test on CUDA") -@pytest.mark.skipif(not find_spec("flashinfer"), - reason="flashinfer is not installed") -@pytest.mark.skipif(not current_platform.is_device_capability(100), - reason="Only test on SM100") -def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +@pytest.mark.skipif(not find_spec("flashinfer"), reason="flashinfer is not installed") +@pytest.mark.skipif( + not current_platform.is_device_capability(100), reason="Only test on SM100" +) +def test_all_reduce_fusion_pass_replace( + test_model: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): num_processes = 2 def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + nprocs=nprocs, + ) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) -def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def all_reduce_fusion_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -102,45 +116,53 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) vllm_config = VllmConfig( - compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - compile_sizes=[2, 4, 8])) + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm"], + compile_sizes=[2, 4, 8], + ) + ) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True) + enable_fi_allreduce_fusion=True + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - task="auto", - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42, + ) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) backend = TestBackend(all_reduce_fusion_pass) model = test_model_cls(hidden_size) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) - residual = torch.randn((batch_size * seq_len, hidden_size), - requires_grad=False) + hidden_states = torch.randn( + (batch_size * seq_len, hidden_size), requires_grad=False + ) + residual = torch.randn((batch_size * seq_len, hidden_size), requires_grad=False) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 70750eb9ac4e..a5a0edd85624 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -21,15 +21,18 @@ @pytest.mark.parametrize( - "model, quant_key", - [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) + "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)] +) @pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) + "use_triton_fa", [True, False] if current_platform.is_rocm() else [False] +) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test CUDA and ROCm" +) +def test_attention_fusion( + example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool +): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset() @@ -55,15 +58,15 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, vllm_config = VllmConfig(compilation_config=compile_config) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) - llm = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.9, - max_model_len=2048) + llm = LLM( + model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.9, + max_model_len=2048, + ) - sampling_params = SamplingParams(temperature=0.0, - max_tokens=10, - top_p=0.95) + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95) unfused_output = llm.generate(prompts, sampling_params) backend_unfused = None # Reset backend to make sure llm gets released @@ -82,17 +85,19 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # so we initialize it during compilation. 
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) - llm2 = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.9, - max_model_len=2048) + llm2 = LLM( + model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.9, + max_model_len=2048, + ) # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported( + quant_key.dtype, quant_key.static, quant_key.group_shape + ) for key, layer in compile_config.static_forward_context.items() ] @@ -109,9 +114,9 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, for i in range(len(attn_nodes_pre)): assert attn_nodes_pre[i].kwargs["output_scale"] is None fused = attn_nodes_post[i].kwargs["output_scale"] is not None - assert fused == attn_fusion_supported[i], \ - f"Node {i} {'' if fused else 'not '} expected " \ - f"to have fused output quant" + assert fused == attn_fusion_supported[i], ( + f"Node {i} {'' if fused else 'not '} expected to have fused output quant" + ) # check outputs fused_output = llm2.generate(prompts, sampling_params) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 251cc46e9e98..ac561d2e8f84 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -28,7 +28,6 @@ def test_bad_callable(): # Pass that inherits from InductorPass class ProperPass(InductorPass): - def __call__(self, graph: torch.fx.graph.Graph) -> None: pass @@ -39,8 +38,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ProperPass(), # Can also wrap callables in CallableInductorPass for compliance CallableInductorPass(simple_callable), - CallableInductorPass(simple_callable, - InductorPass.hash_source(__file__)) + CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)), ], ) def test_pass_manager_uuid(callable): @@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.compilation_config.pass_config.enable_fusion = not \ - config2.compilation_config.pass_config.enable_fusion + config2.compilation_config.pass_config.enable_fusion = ( + not config2.compilation_config.pass_config.enable_fusion + ) pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index b56edfc90612..4251ae7a9a37 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -10,14 +10,20 @@ from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm 
import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -34,16 +40,15 @@ class TestModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size))) + torch.empty((intermediate_size, hidden_size)) + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -51,18 +56,18 @@ def __init__(self, def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -80,7 +85,7 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] def ops_in_model(self): @@ -88,47 +93,45 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.vllm_config = vllm_config - self.gate_proj = torch.nn.Parameter(torch.empty( - (intermediate_size, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size)), requires_grad=False + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, - use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp( + cutlass_fp8_supported=True, use_per_token_if_dynamic=False + ) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. 
- self.w = torch.rand(hidden_size, - intermediate_size).to(dtype=FP8_DTYPE).t() + self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -140,45 +143,51 @@ def forward(self, hidden_states, residual): # for static input quantization # self.fp8_linear is initialized with use_per_token_if_dynamic=False - fp8_linear_result = self.fp8_linear.apply(norm_output, - self.w, - self.wscale, - input_scale=self.scale.to( - norm_output.device)) + fp8_linear_result = self.fp8_linear.apply( + norm_output, + self.w, + self.wscale, + input_scale=self.scale.to(norm_output.device), + ) return fp8_linear_result, residual_output def ops_in_model_before(self): - ops_to_remove = [torch.ops.vllm.all_reduce.default - ] # Always removed by SP + ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # The following are only removed if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_remove.extend([ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ]) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_remove.extend( + [ + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.static_scaled_fp8_quant.default, + ] + ) return ops_to_remove def ops_in_model_after(self): ops_to_add = [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] # The following is only added if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_add.append( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) return ops_to_add def ops_in_model(self): - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): # If fusion happens, the fused op is the one # we check for (de)functionalization - return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - ] # noqa: E501 + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501 else: # If no fusion, the original ops are checked return [ @@ -195,30 +204,47 @@ def ops_in_model(self): @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, - enable_fusion: bool): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_sequence_parallelism_pass( + test_model_cls: 
type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model_cls, - batch_size, seq_len, hidden_size, - dtype, enable_fusion), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + test_model_cls, + batch_size, + seq_len, + hidden_size, + dtype, + enable_fusion, + ), + nprocs=nprocs, + ) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) def sequence_parallelism_pass_on_test_model( - local_rank: int, world_size: int, - test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, enable_fusion: bool): + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -226,13 +252,15 @@ def sequence_parallelism_pass_on_test_model( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -240,22 +268,27 @@ def sequence_parallelism_pass_on_test_model( # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_sequence_parallelism=True, - enable_fusion=enable_fusion, - enable_noop=True)) # NoOp needed for fusion + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_sequence_parallelism=True, + enable_fusion=enable_fusion, + enable_noop=True, + ) + ) # NoOp needed for fusion vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - task="auto", - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42, + ) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) @@ -270,12 +303,9 @@ def sequence_parallelism_pass_on_test_model( backend_no_func = TestBackend(*passes_for_backend) backend_func = TestBackend(*passes_for_backend, func_pass) - model = test_model_cls(hidden_size, - hidden_size * 2, - vllm_config=vllm_config) + model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) compiled_model_no_func = torch.compile(model, backend=backend_no_func) @@ -294,8 +324,7 @@ def sequence_parallelism_pass_on_test_model( # check if the functionalization pass is applied for op in model.ops_in_model(): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501 # make sure the ops were all de-functionalized found = dict() diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 5351a3cf35ba..fa2446beb327 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -9,27 +9,28 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp) + CUTLASS_FP8_SUPPORTED, + Fp8LinearOp, +) from vllm.platforms import current_platform from .backend import TestBackend class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, - **kwargs): + def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, **kwargs): super().__init__(*args, **kwargs) self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = ( + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + ) self.fp8_linear = Fp8LinearOp( cutlass_fp8_supported=cutlass_fp8_enabled, @@ -39,28 +40,27 @@ def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, - self.w, - self.wscale, - input_scale=self.wscale) + x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return x2 @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) 
-@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], - reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, - cutlass_fp8_enabled): +@pytest.mark.parametrize( + "cutlass_fp8_enabled", [True, False] if CUTLASS_FP8_SUPPORTED else [False] +) +@pytest.mark.skipif( + envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" +) +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=True, enable_noop=True)) + pass_config=PassConfig(enable_fusion=True, enable_noop=True) + ) fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) @@ -76,10 +76,12 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, result2 = model2(x) # Check that it gives the same answer - torch.testing.assert_close(result[0].to(dtype=torch.float16), - result2[0].to(dtype=torch.float16), - atol=1e-3, - rtol=1e-3) + torch.testing.assert_close( + result[0].to(dtype=torch.float16), + result2[0].to(dtype=torch.float16), + atol=1e-3, + rtol=1e-3, + ) # Check substitution worked pre_nodes = backend.graph_pre_pass.nodes diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 5e39f6821d16..34db5a999cbd 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -10,7 +10,6 @@ class MyMod(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): if cache is not None: return x + cache @@ -18,12 +17,12 @@ def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): class MyWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable, - compilation_level=CompilationLevel.DYNAMO_ONCE) + super().__init__( + compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE + ) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled @@ -54,10 +53,8 @@ def test_torch_compile_wrapper(): # for new input, dispatch to the compiled code directly new_x = torch.tensor([3]) - assert wrapper(new_x, - None).item() == 6 # dispatch to the first compiled code - assert wrapper( - new_x, cache).item() == 5 # dispatch to the second compiled code + assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code + assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code for wrapper in wrappers: # make sure they have independent compiled codes diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py index 024e81fccc5f..cff998d14817 100644 --- a/tests/config/test_config_generation.py +++ b/tests/config/test_config_generation.py @@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch): """ def create_config(): - engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True) + engine_args = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ) return engine_args.create_engine_config() # Create config with 
CUDA_VISIBLE_DEVICES set normally @@ -34,5 +35,6 @@ def create_config(): empty_config_dict.pop("instance_id", None) assert deep_compare(normal_config_dict, empty_config_dict), ( - "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" - " should be equivalent") + 'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""' + " should be equivalent" + ) diff --git a/tests/config/test_mp_reducer.py b/tests/config/test_mp_reducer.py index ee351cbfa7c1..5a2e784ac7f4 100644 --- a/tests/config/test_mp_reducer.py +++ b/tests/config/test_mp_reducer.py @@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch): """ # Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") # Ensure transformers_modules is not in sys.modules - if 'transformers_modules' in sys.modules: - del sys.modules['transformers_modules'] + if "transformers_modules" in sys.modules: + del sys.modules["transformers_modules"] - with patch('multiprocessing.reducer.register') as mock_register: + with patch("multiprocessing.reducer.register") as mock_register: engine_args = AsyncEngineArgs( model="facebook/opt-125m", max_model_len=32, @@ -37,7 +37,8 @@ def test_mp_reducer(monkeypatch): ) assert mock_register.called, ( - "multiprocessing.reducer.register should have been called") + "multiprocessing.reducer.register should have been called" + ) vllm_config_registered = False for call_args in mock_register.call_args_list: @@ -46,8 +47,7 @@ def test_mp_reducer(monkeypatch): vllm_config_registered = True reducer_func = call_args[0][1] - assert callable( - reducer_func), "Reducer function should be callable" + assert callable(reducer_func), "Reducer function should be callable" break assert vllm_config_registered, ( diff --git a/tests/conftest.py b/tests/conftest.py index f3524d1fe2a6..2ebcae46f656 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,23 +13,33 @@ import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BatchEncoding, BatchFeature) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BatchEncoding, + BatchFeature, +) from transformers.models.auto.auto_factory import _BaseAutoModelClass -from tests.models.utils import (TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs) +from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import TaskOption, _get_and_verify_dtype from vllm.connections import global_http_connection -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.inputs import ( + ExplicitEncoderDecoderPrompt, + TextPrompt, + to_enc_dec_tuple_list, + zip_enc_dec_prompts, +) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams @@ -63,12 +73,13 @@ class ImageAssetPrompts(TypedDict): class ImageTestAssets(list[ImageAsset]): - def __init__(self) -> None: - 
super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ @@ -85,11 +96,12 @@ class VideoAssetPrompts(TypedDict): class VideoTestAssets(list[VideoAsset]): - def __init__(self) -> None: - super().__init__([ - VideoAsset("baby_reading"), - ]) + super().__init__( + [ + VideoAsset("baby_reading"), + ] + ) def prompts(self, prompts: VideoAssetPrompts) -> list[str]: return [prompts["baby_reading"]] @@ -101,12 +113,13 @@ class AudioAssetPrompts(TypedDict): class AudioTestAssets(list[AudioAsset]): - def __init__(self) -> None: - super().__init__([ - AudioAsset("mary_had_lamb"), - AudioAsset("winning_call"), - ]) + super().__init__( + [ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ] + ) def prompts(self, prompts: AudioAssetPrompts) -> list[str]: return [prompts["mary_had_lamb"], prompts["winning_call"]] @@ -151,11 +164,11 @@ def run_with_both_engines(request, monkeypatch): if use_v1: if skip_v1: pytest.skip("Skipping test on vllm V1") - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") else: if skip_v0: pytest.skip("Skipping test on vllm V0") - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") yield @@ -221,15 +234,17 @@ def example_system_message() -> str: class DecoderPromptType(Enum): """For encoder/decoder models only.""" + CUSTOM = 1 NONE = 2 EMPTY_STR = 3 @pytest.fixture -def example_encoder_decoder_prompts( -) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]: - ''' +def example_encoder_decoder_prompts() -> dict[ + DecoderPromptType, list[ExplicitEncoderDecoderPrompt] +]: + """ Returns an encoder prompt list and a decoder prompt list, wherein each pair of same-index entries in both lists corresponds to an (encoder prompt, decoder prompt) tuple. 
@@ -238,7 +253,7 @@ def example_encoder_decoder_prompts( * Encoder prompt list * Decoder prompt list (reverse of encoder prompt list) - ''' + """ encoder_prompts = [] for filename in _TEST_PROMPTS: @@ -250,12 +265,15 @@ def example_encoder_decoder_prompts( # NONE decoder prompt type return { - DecoderPromptType.NONE: - zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts), - DecoderPromptType.EMPTY_STR: - zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts), - DecoderPromptType.CUSTOM: - zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts), + DecoderPromptType.NONE: zip_enc_dec_prompts( + encoder_prompts, none_decoder_prompts + ), + DecoderPromptType.EMPTY_STR: zip_enc_dec_prompts( + encoder_prompts, empty_str_decoder_prompts + ), + DecoderPromptType.CUSTOM: zip_enc_dec_prompts( + encoder_prompts, custom_decoder_prompts + ), } @@ -287,15 +305,13 @@ def audio_assets() -> AudioTestAssets: class HfRunner: - def get_default_device(self): from vllm.platforms import current_platform - return ("cpu" - if current_platform.is_cpu() else current_platform.device_type) + return "cpu" if current_platform.is_cpu() else current_platform.device_type def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: - if x is None or isinstance(x, (bool, )): + if x is None or isinstance(x, (bool,)): return x if device is None: @@ -367,14 +383,15 @@ def __init__( ) # in case some unquantized custom models are not in same dtype - if (getattr(model, "quantization_method", None) is None - and any(p.dtype != self.dtype - for p in model.parameters())): + if getattr(model, "quantization_method", None) is None and any( + p.dtype != self.dtype for p in model.parameters() + ): model = model.to(dtype=self.dtype) - if (getattr(model, "quantization_method", None) != "bitsandbytes" - and len({p.device - for p in model.parameters()}) < 2): + if ( + getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device for p in model.parameters()}) < 2 + ): model = model.to(device=self.device) self.model = model @@ -389,6 +406,7 @@ def __init__( # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, @@ -469,10 +487,9 @@ def generate( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: @@ -499,16 +516,17 @@ def generate_greedy( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - images=images, - videos=videos, - audios=audios, - **kwargs) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_beam_search( self, @@ -519,21 +537,22 @@ def generate_beam_search( videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - outputs = self.generate(prompts, - 
do_sample=False, - max_new_tokens=max_tokens, - num_beams=beam_width, - num_return_sequences=beam_width, - images=images, - videos=videos, - audios=audios) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios, + ) for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): output_ids[j] = [ - x for x in output_ids[j] - if x != self.tokenizer.pad_token_id + x for x in output_ids[j] if x != self.tokenizer.pad_token_id ] outputs[i] = (output_ids, output_str) return outputs @@ -547,10 +566,9 @@ def generate_greedy_logprobs( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[list[torch.Tensor]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: @@ -563,8 +581,7 @@ def generate_greedy_logprobs( return_dict_in_generate=True, **kwargs, ) - seq_logprobs = self._hidden_states_to_seq_logprobs( - output.hidden_states) + seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states) all_logprobs.append(seq_logprobs) return all_logprobs @@ -628,10 +645,9 @@ def generate_greedy_logprobs_limit( videos: Optional[PromptVideoInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[dict[int, float]]] = [] all_output_ids: list[list[int]] = [] @@ -651,8 +667,7 @@ def generate_greedy_logprobs_limit( ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.hidden_states, - num_logprobs) + ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] @@ -662,8 +677,10 @@ def generate_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] def generate_encoder_decoder_greedy_logprobs_limit( self, @@ -673,16 +690,17 @@ def generate_encoder_decoder_greedy_logprobs_limit( images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: - ''' + """ Greedy logprobs generation for vLLM encoder/decoder models - ''' + """ all_logprobs: list[list[dict[int, float]]] = [] all_output_ids: list[list[int]] = [] all_output_strs: list[str] = [] for i, (encoder_prompt, decoder_prompt) in enumerate( - to_enc_dec_tuple_list(encoder_decoder_prompts)): + to_enc_dec_tuple_list(encoder_decoder_prompts) + ): processor_kwargs: dict[str, Any] = { "text": encoder_prompt, "return_tensors": "pt", @@ -696,8 +714,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( if decoder_prompt is None: decoder_input_ids = None else: - decoder_inputs = self.tokenizer(decoder_prompt, - return_tensors="pt") + decoder_inputs = self.tokenizer(decoder_prompt, return_tensors="pt") decoder_input_ids = self.wrap_device(decoder_inputs.input_ids) output = self.model.generate( @@ -714,8 +731,9 @@ def 
generate_encoder_decoder_greedy_logprobs_limit( ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.decoder_hidden_states, - num_logprobs) + ) = self._hidden_states_to_logprobs( + output.decoder_hidden_states, num_logprobs + ) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] @@ -724,19 +742,16 @@ def generate_encoder_decoder_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] - def encode(self, prompts: list[str], *args, - **kwargs) -> list[list[torch.Tensor]]: + def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: return self.model.encode(prompts, *args, **kwargs) - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - return self.model.predict(prompts, - *args, - convert_to_tensor=True, - **kwargs) + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs) def __enter__(self): return self @@ -809,12 +824,12 @@ def get_inputs( videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> list[TextPrompt]: - - if any(x is not None and len(x) != len(prompts) - for x in [images, videos, audios]): + if any( + x is not None and len(x) != len(prompts) for x in [images, videos, audios] + ): raise ValueError( - "All non-None multimodal inputs must have the same length as " - "prompts") + "All non-None multimodal inputs must have the same length as prompts" + ) inputs = [] for i, prompt in enumerate(prompts): @@ -849,14 +864,11 @@ def generate( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.model.generate(inputs, - sampling_params=sampling_params, - **kwargs) + req_outputs = self.model.generate( + inputs, sampling_params=sampling_params, **kwargs + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for req_output in req_outputs: @@ -883,8 +895,9 @@ def _final_steps_generate_w_logprobs( output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs, - req_output.prompt_logprobs)) + outputs.append( + (output_ids, output_str, output_logprobs, req_output.prompt_logprobs) + ) return outputs def generate_w_logprobs( @@ -895,43 +908,45 @@ def generate_w_logprobs( audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) - - req_outputs = self.model.generate(inputs, - sampling_params=sampling_params, - **kwargs) - - toks_str_logsprobs_prompt_logprobs = ( - self._final_steps_generate_w_logprobs(req_outputs)) + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + req_outputs = self.model.generate( + inputs, 
sampling_params=sampling_params, **kwargs + ) + + toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs( + req_outputs + ) # Omit prompt logprobs if not required by sampling params - return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.prompt_logprobs is None else - toks_str_logsprobs_prompt_logprobs) + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) def generate_encoder_decoder_w_logprobs( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - ''' + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + """ Logprobs generation for vLLM encoder/decoder models - ''' + """ assert sampling_params.logprobs is not None - req_outputs = self.model.generate(encoder_decoder_prompts, - sampling_params=sampling_params) - toks_str_logsprobs_prompt_logprobs = ( - self._final_steps_generate_w_logprobs(req_outputs)) + req_outputs = self.model.generate( + encoder_decoder_prompts, sampling_params=sampling_params + ) + toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs( + req_outputs + ) # Omit prompt logprobs if not required by sampling params - return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.prompt_logprobs is None else - toks_str_logsprobs_prompt_logprobs) + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) def generate_greedy( self, @@ -943,14 +958,15 @@ def generate_greedy( **kwargs: Any, ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, - greedy_params, - images=images, - videos=videos, - audios=audios, - **kwargs) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + outputs = self.generate( + prompts, + greedy_params, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_greedy_logprobs( self, @@ -964,22 +980,24 @@ def generate_greedy_logprobs( stop_token_ids: Optional[list[int]] = None, stop: Optional[list[str]] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids, - stop=stop) + stop=stop, + ) - return self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos, - **kwargs) + return self.generate_w_logprobs( + prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos, + **kwargs, + ) def generate_encoder_decoder_greedy_logprobs( self, @@ -988,8 +1006,7 @@ def generate_encoder_decoder_greedy_logprobs( num_logprobs: int, num_prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, 
max_tokens=max_tokens, @@ -997,12 +1014,13 @@ def generate_encoder_decoder_greedy_logprobs( prompt_logprobs=(num_prompt_logprobs), skip_special_tokens=skip_special_tokens, ) - ''' + """ Greedy logprobs generation for vLLM encoder/decoder models - ''' + """ return self.generate_encoder_decoder_w_logprobs( - encoder_decoder_prompts, greedy_logprobs_params) + encoder_decoder_prompts, greedy_logprobs_params + ) def generate_beam_search( self, @@ -1013,14 +1031,11 @@ def generate_beam_search( videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) outputs = self.model.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + inputs, BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens) + ) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] @@ -1032,17 +1047,16 @@ def classify(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.model.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] - def embed(self, - prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - *args, - **kwargs) -> list[list[float]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + def embed( + self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs, + ) -> list[list[float]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.model.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] @@ -1081,6 +1095,7 @@ def vllm_runner(): @pytest.fixture() def temporary_enable_log_propagate(): import logging + logger = logging.getLogger("vllm") logger.propagate = True yield @@ -1100,6 +1115,7 @@ def num_gpus_available(): in current process.""" from vllm.platforms import current_platform + return current_platform.device_count() @@ -1113,12 +1129,11 @@ def num_gpus_available(): def dummy_opt_path(): json_path = os.path.join(_dummy_opt_path, "config.json") if not os.path.exists(_dummy_opt_path): - snapshot_download(repo_id="facebook/opt-125m", - local_dir=_dummy_opt_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="facebook/opt-125m", + local_dir=_dummy_opt_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1132,12 +1147,11 @@ def dummy_opt_path(): def dummy_llava_path(): json_path = os.path.join(_dummy_llava_path, "config.json") if not os.path.exists(_dummy_llava_path): - snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf", - local_dir=_dummy_llava_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="llava-hf/llava-1.5-7b-hf", + local_dir=_dummy_llava_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1151,12 +1165,11 @@ def 
dummy_llava_path(): def dummy_gemma2_embedding_path(): json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") if not os.path.exists(_dummy_gemma2_embedding_path): - snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", - local_dir=_dummy_gemma2_embedding_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1169,10 +1182,9 @@ def dummy_gemma2_embedding_path(): # Add the flag `--optional` to allow run tests # that are marked with @pytest.mark.optional def pytest_addoption(parser): - parser.addoption("--optional", - action="store_true", - default=False, - help="run optional test") + parser.addoption( + "--optional", action="store_true", default=False, help="run optional test" + ) def pytest_collection_modifyitems(config, items): diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index e2c6c66b259c..c6e9bf88e71e 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -12,21 +12,26 @@ @pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed) +def baseline_llm_generator( + common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, seed +): + return create_llm_generator( + common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, seed + ) @pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) +def test_llm_generator( + common_llm_kwargs, per_test_common_llm_kwargs, test_llm_kwargs, seed +): + return create_llm_generator( + common_llm_kwargs, per_test_common_llm_kwargs, test_llm_kwargs, seed + ) -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): +def create_llm_generator( + common_llm_kwargs, per_test_common_llm_kwargs, distinct_llm_kwargs, seed +): kwargs = { **common_llm_kwargs, **per_test_common_llm_kwargs, @@ -47,11 +52,12 @@ def generator_inner(): del llm -def get_text_from_llm_generator(llm_generator: Iterable[LLM], - prompts, - sampling_params, - llm_cb: Optional[Callable[[LLM], - None]] = None): +def get_text_from_llm_generator( + llm_generator: Iterable[LLM], + prompts, + sampling_params, + llm_cb: Optional[Callable[[LLM], None]] = None, +): for llm in llm_generator: if llm_cb: llm_cb(llm) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 93222b564ebe..8c25c06e78da 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -12,28 +12,28 @@ @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + # skip cuda graph creation for fast test. 
+ "enforce_eager": True, + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) +@pytest.mark.parametrize( + "test_llm_kwargs", [{"preemption_mode": "swap"}, {"preemption_mode": "recompute"}] +) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_block_manager_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): +def test_block_manager_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size +): """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit @@ -47,8 +47,8 @@ def test_block_manager_with_preemption(baseline_llm_generator, KV mapping has time to build up error. NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we keep this test. """ output_len = 1024 @@ -74,13 +74,14 @@ def test_block_manager_with_preemption(baseline_llm_generator, ) baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids @@ -88,38 +89,43 @@ def test_block_manager_with_preemption(baseline_llm_generator, @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # Our prompts will generate 128 tokens; since the prompts themselves are - # small, we don't need much KV space beyond 128. - "max_model_len": 160, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - }]) + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + # Our prompts will generate 128 tokens; since the prompts themselves are + # small, we don't need much KV space beyond 128. + "max_model_len": 160, + # skip cuda graph creation for fast test. + "enforce_eager": True, + } + ], +) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ { "block_size": 16, - # Allow only 2 sequences of ~128 tokens in worst case. # Note 8 = 128/block_size "num_gpu_blocks_override": 2 * (8 + 1), }, { "block_size": 8, - # Allow only 2 sequences of ~128 tokens in worst case. 
# Note 16 = 128/block_size "num_gpu_blocks_override": 2 * (16 + 2), + }, + ], +) +@pytest.mark.parametrize( + "baseline_llm_kwargs", + [ + { + "num_lookahead_slots": 0, } - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "num_lookahead_slots": 0, -}]) + ], +) @pytest.mark.parametrize( "test_llm_kwargs", [ @@ -132,13 +138,14 @@ def test_block_manager_with_preemption(baseline_llm_generator, { "num_lookahead_slots": 10, "preemption_mode": "recompute", - } - ]) + }, + ], +) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size): +def test_lookahead_greedy_equality_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size +): """Verify vLLM produces the same output with greedy sampling, when lookahead scheduling is used vs. not. @@ -167,16 +174,17 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids without lookahead scheduling') + print("Getting token ids without lookahead scheduling") baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - print('Getting token ids with lookahead scheduling') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + print("Getting token ids with lookahead scheduling") + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids @@ -188,42 +196,55 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, { # Use a small model for a fast test. "model": "facebook/opt-125m", - # skip cuda graph creation for fast test. 
"enforce_eager": True, "enable_chunked_prefill": True, }, - ]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", - [{ - "block_size": 16, - "max_num_batched_tokens": 2, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 3, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 256, - "max_num_seqs": 10, - }]) -@pytest.mark.parametrize("baseline_llm_kwargs", [ - {}, -]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "num_lookahead_slots": 0, - }, - { - "num_lookahead_slots": 5, - }, -]) + ], +) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "block_size": 16, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, + { + "block_size": 16, + "max_num_batched_tokens": 3, + "max_num_seqs": 2, + }, + { + "block_size": 16, + "max_num_batched_tokens": 256, + "max_num_seqs": 10, + }, + ], +) +@pytest.mark.parametrize( + "baseline_llm_kwargs", + [ + {}, + ], +) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "num_lookahead_slots": 0, + }, + { + "num_lookahead_slots": 5, + }, + ], +) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify that chunked prefill works with SelfAttnBlockSpaceManager, +def test_chunked_prefill_block_manager( + baseline_llm_generator, test_llm_generator, batch_size +): + """Verify that chunked prefill works with SelfAttnBlockSpaceManager, with and without lookahead scheduling. """ output_len = 32 @@ -245,16 +266,17 @@ def test_chunked_prefill_block_manager(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids with BlockManager') + print("Getting token ids with BlockManager") baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - print('Getting token ids with BlockManager, with lookahead slots.') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + print("Getting token ids with BlockManager, with lookahead slots.") + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids @@ -262,31 +284,30 @@ def test_chunked_prefill_block_manager(baseline_llm_generator, @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - - # Enable prefill cache - "enable_prefix_caching": True, - }]) + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + # skip cuda graph creation for fast test. + "enforce_eager": True, + # Allow only 5 sequences of ~1024 tokens in worst case. 
+ "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + # Enable prefill cache + "enable_prefix_caching": True, + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) +@pytest.mark.parametrize( + "test_llm_kwargs", [{"preemption_mode": "swap"}, {"preemption_mode": "recompute"}] +) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) def test_block_manager_prefix_caching_enabled_with_preemption( - baseline_llm_generator, test_llm_generator, batch_size): + baseline_llm_generator, test_llm_generator, batch_size +): """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit @@ -300,8 +321,8 @@ def test_block_manager_prefix_caching_enabled_with_preemption( KV mapping has time to build up error. NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we keep this test. """ output_len = 1024 @@ -326,16 +347,17 @@ def test_block_manager_prefix_caching_enabled_with_preemption( temperature=temperature, ) - print('Getting token ids from block manager') + print("Getting token ids from block manager") baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - print('Getting token ids from block manager, with preemption') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + print("Getting token ids from block manager, with preemption") + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids @@ -343,32 +365,32 @@ def test_block_manager_prefix_caching_enabled_with_preemption( @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + # skip cuda graph creation for fast test. + "enforce_eager": True, + # Allow only 5 sequences of ~1024 tokens in worst case. 
+ "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, - "preemption_mode": "swap" -}, { - "enable_prefix_caching": True, - "preemption_mode": "recompute" -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{"enable_prefix_caching": False}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + {"enable_prefix_caching": True, "preemption_mode": "swap"}, + {"enable_prefix_caching": True, "preemption_mode": "recompute"}, + ], +) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): +def test_auto_prefix_caching_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size +): """Verify block manager v2 with auto prefix caching enabled produces same outputs as auto prefix caching disabled, even when there is preemption. @@ -400,16 +422,17 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids with APC disabled') + print("Getting token ids with APC disabled") baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + print("Getting token ids with APC enabled") + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids @@ -417,28 +440,33 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # we keep the blocks small, so that hit eviction quickly - "max_model_len": 48, - "block_size": 16, - "num_gpu_blocks_override": 3, - }]) + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + # skip cuda graph creation for fast test. + "enforce_eager": True, + # we keep the blocks small, so that hit eviction quickly + "max_model_len": 48, + "block_size": 16, + "num_gpu_blocks_override": 3, + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{"enable_prefix_caching": False}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "enable_prefix_caching": True, + } + ], +) @pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator, - test_llm_generator): +def test_auto_prefix_caching_after_eviction_start( + baseline_llm_generator, test_llm_generator +): """Verify block manager v2 with auto prefix caching could works normal even when eviction started. 
With APC enabled, all blocks are held by native block at the beginning. @@ -455,7 +483,7 @@ def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator, "You are a helpful assistant. Please answer truthfully and write out " "your thinking step by step to be sure you get the right answer. You " "are helpful and harmless and you follow ethical guidelines. " - "who are you?" + "who are you?", ] sampling_params = SamplingParams( @@ -464,16 +492,17 @@ def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids with APC disabled') + print("Getting token ids with APC disabled") baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + baseline_llm_generator, prompts, sampling_params + ) - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) + print("Getting token ids with APC enabled") + test_token_ids = get_token_ids_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, test_token_ids): assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 4d67eea2264b..eed7c3387e1f 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -18,23 +18,26 @@ @pytest.mark.parametrize( "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) + [ + { + "model": MODEL, + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, - batch_size, seed, backend, monkeypatch): +def test_sliding_window_retrieval( + baseline_llm_generator, test_llm_generator, batch_size, seed, backend, monkeypatch +): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). 
@@ -58,16 +61,16 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, prompts, answer, indices = prep_prompts(batch_size) - baseline_texts = get_text_from_llm_generator(baseline_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) + baseline_texts = get_text_from_llm_generator( + baseline_llm_generator, prompts, sampling_params, llm_cb=check_window(prompts) + ) check_answers(indices, answer, baseline_texts) - print('Getting token ids from block manager v2') - test_texts = get_text_from_llm_generator(test_llm_generator, prompts, - sampling_params) + print("Getting token ids from block manager v2") + test_texts = get_text_from_llm_generator( + test_llm_generator, prompts, sampling_params + ) check_answers(indices, answer, test_texts) cmp = [ @@ -84,21 +87,24 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize( "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) + [ + { + "model": MODEL, + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + } + ], +) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, - backend, monkeypatch): +def test_sliding_window_chunked_prefill( + test_llm_generator, batch_size, seed, backend, monkeypatch +): """ This is similar to test_sliding_window_retrieval, however, it doesn't compare against the v1 block manager since v1 doesn't support @@ -123,10 +129,9 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, # We don't compare with the baseline model here, since the results # slightly different due to different tailing in attention. 
- test_texts = get_text_from_llm_generator(test_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) + test_texts = get_text_from_llm_generator( + test_llm_generator, prompts, sampling_params, llm_cb=check_window(prompts) + ) check_answers(indices, answer, test_texts) @@ -148,8 +153,10 @@ def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): for _ in range(batch_size): idx = random.randint(30, 90) indices.append(idx) - prompt = "```python\n# We set a number of variables, " + \ - f"x{idx} will be important later\n" + prompt = ( + "```python\n# We set a number of variables, " + + f"x{idx} will be important later\n" + ) ln = random.randint(*ln_range) for k in range(30, ln): v = random.randint(10, 99) @@ -162,10 +169,9 @@ def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): return prompts, answer, indices -def check_answers(indices: list[int], - answer: list[int], - outputs: list[str], - accept_rate: float = 0.7): +def check_answers( + indices: list[int], answer: list[int], outputs: list[str], accept_rate: float = 0.7 +): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 @@ -178,12 +184,12 @@ def check_answers(indices: list[int], def check_window(prompts: list[str]): - def inner(llm: LLM): sliding_window = llm.llm_engine.model_config.get_sliding_window() assert sliding_window and sliding_window > 0 assert any( len(llm.get_tokenizer().tokenize(prompt)) > sliding_window - for prompt in prompts) + for prompt in prompts + ) return inner diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py index 9eed264fd7d4..b3344bdd65c8 100644 --- a/tests/core/block/test_block_manager.py +++ b/tests/core/block/test_block_manager.py @@ -3,23 +3,29 @@ import pytest -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.core.block.utils import ( + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA, +) from vllm.core.block_manager import SelfAttnBlockSpaceManager from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import (create_dummy_prompt, create_seq_group, - create_seq_group_encoder_decoder) +from ..utils import ( + create_dummy_prompt, + create_seq_group, + create_seq_group_encoder_decoder, +) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) @pytest.mark.parametrize("num_seqs_per_group", [1, 4]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): +def test_can_allocate_seq_group( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float +): block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, @@ -62,10 +68,9 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, @pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) @pytest.mark.parametrize("num_seqs_per_group", [1, 4]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group_encoder_decoder(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): +def test_can_allocate_seq_group_encoder_decoder( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float +): block_manager = SelfAttnBlockSpaceManager( 
block_size=block_size, num_gpu_blocks=num_gpu_blocks, @@ -82,7 +87,8 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_output_blocks = num_output_blocks_per_seq for bdx, num_prompt_blocks in enumerate( - range(1, num_gpu_blocks - num_output_blocks)): + range(1, num_gpu_blocks - num_output_blocks) + ): num_cross_blocks_per_seq = num_prompt_blocks seq_group = create_seq_group_encoder_decoder( @@ -91,15 +97,16 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, block_size * num_output_blocks_per_seq for _ in range(num_seqs_per_group) ], - request_id=str(bdx)) + request_id=str(bdx), + ) assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks can_allocate_result = block_manager.can_allocate(seq_group) - num_required_blocks = num_prompt_blocks + \ - num_output_blocks + \ - num_cross_blocks_per_seq + num_required_blocks = ( + num_prompt_blocks + num_output_blocks + num_cross_blocks_per_seq + ) if num_gpu_blocks - num_required_blocks < num_watermark_blocks: assert can_allocate_result == AllocStatus.NEVER @@ -113,11 +120,10 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, @pytest.mark.parametrize("num_gpu_blocks", [16]) @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): - ''' +def test_can_allocate_encoder_decoder_fails_with_swa( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float +): + """ SWA short for Sliding Window Attention. At time of writing block manager does not support SWA. @@ -135,7 +141,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, The setup for this test is stripped down version of test_can_allocate_seq_group_encoder_decoder() - ''' + """ with pytest.raises((NotImplementedError, AssertionError)) as exc_info: block_manager = SelfAttnBlockSpaceManager( @@ -143,7 +149,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, watermark=watermark, - sliding_window=5 # SWA + sliding_window=5, # SWA ) num_output_blocks_per_seq = 1 @@ -155,7 +161,8 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, block_size * num_output_blocks_per_seq for _ in range(num_seqs_per_group) ], - request_id="0") + request_id="0", + ) assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks block_manager.can_allocate(seq_group) @@ -177,15 +184,14 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) def test_can_allocate_encoder_decoder_fails_with_prefix_cache( - block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, - watermark: float): - + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float +): block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, watermark=watermark, - enable_caching=True # Prefix cache + enable_caching=True, # Prefix cache ) num_output_blocks_per_seq = 1 @@ -194,10 +200,10 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( seq_group = create_seq_group_encoder_decoder( seq_prompt_len=block_size * num_prompt_blocks, seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) + block_size * num_output_blocks_per_seq for _ in 
range(num_seqs_per_group) ], - request_id="0") + request_id="0", + ) assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks @@ -212,8 +218,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) @pytest.mark.parametrize("num_lookahead_slots", [0, 10]) -def test_append_slots(block_size, prompt_len, num_slots_to_append, - num_lookahead_slots): +def test_append_slots(block_size, prompt_len, num_slots_to_append, num_lookahead_slots): """Verify append_slots consumes the correct number of blocks from the block table. """ @@ -247,18 +252,19 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, # Append slots for new tokens and lookahead slots. free_blocks_before_append = block_manager.get_num_free_gpu_blocks() block_manager.append_slots(seq, num_lookahead_slots) - num_consumed_blocks = (free_blocks_before_append - - block_manager.get_num_free_gpu_blocks()) + num_consumed_blocks = ( + free_blocks_before_append - block_manager.get_num_free_gpu_blocks() + ) # Expect consumed blocks to be new blocks required to support the new slots. expected_consumed_blocks = len( list( chunk_list( - list( - range(prompt_len + num_slots_to_append + - num_lookahead_slots)), - block_size))) - len( - list(chunk_list(list(range(prompt_len)), block_size))) + list(range(prompt_len + num_slots_to_append + num_lookahead_slots)), + block_size, + ) + ) + ) - len(list(chunk_list(list(range(prompt_len)), block_size))) assert num_consumed_blocks == expected_consumed_blocks @@ -267,16 +273,19 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, @pytest.mark.parametrize("num_gpu_blocks", [4]) @pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) @pytest.mark.parametrize("enable_caching", [False, True]) -def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, - enable_caching): +def test_swap( + block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, enable_caching +): """Verify blocks number on src/desc device is correct after swapping in/out - sequence group (not missing or extra blocks). + sequence group (not missing or extra blocks). """ - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager( + block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching, + ) prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) @@ -319,19 +328,21 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, @pytest.mark.parametrize("num_gpu_blocks", [4]) @pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10]) @pytest.mark.parametrize("enable_caching", [True, False]) -def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, - enable_caching): - """ Verify the block manager can correctly determine if a sequence group - can be swapped in/out. +def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, enable_caching): + """Verify the block manager can correctly determine if a sequence group + can be swapped in/out. 
""" num_cpu_blocks = num_gpu_blocks - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager( + block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching, + ) prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) + "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1 + ) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) prompt.status = SequenceStatus.RUNNING @@ -352,11 +363,14 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, # At this moment, we still have enough free blocks to swap in the seq group. if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) == AllocStatus.OK + ) else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) + == AllocStatus.NEVER + ) # During Swapped out, 2 cached blocks were evicted from the GPU, # so the prompt1 can't be swapped in @@ -364,17 +378,22 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, prompt2, seq_group2 = create_dummy_prompt( "2", prompt_length=prompt2_len, - prompt_tokens=[10000 + i for i in range(prompt2_len)]) + prompt_tokens=[10000 + i for i in range(prompt2_len)], + ) prompt2.status = SequenceStatus.WAITING block_manager.allocate(seq_group2) # Swap seq group from CPU -> GPU. if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.LATER + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) + == AllocStatus.LATER + ) else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) + == AllocStatus.NEVER + ) @pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) @@ -386,11 +405,13 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching): block_size = 8 num_cpu_blocks = 1 num_gpu_blocks = 1 - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager( + block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching, + ) prompt_length = block_size - 3 assert prompt_length > 0 prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) @@ -414,13 +435,17 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching): # the total number of available GPU blocks then the swap # should fail. num_unseen_tokens = 1 - if (num_lookahead_slots + num_unseen_tokens + - prompt_length) <= (block_size * num_gpu_blocks): - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK + if (num_lookahead_slots + num_unseen_tokens + prompt_length) <= ( + block_size * num_gpu_blocks + ): + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) == AllocStatus.OK + ) else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER + assert ( + block_manager.can_swap_in(seq_group, num_lookahead_slots) + == AllocStatus.NEVER + ) # TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. 
@@ -430,8 +455,7 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching): @pytest.mark.parametrize("prompt_len", [10, 300, 1000]) @pytest.mark.parametrize("num_slots_to_append", [50]) @pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) -def test_sliding_window(block_size, prompt_len, num_slots_to_append, - sliding_window): +def test_sliding_window(block_size, prompt_len, num_slots_to_append, sliding_window): """Verify append_slots consumes the correct number of blocks from the block table. """ diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index ba085001136b..157bb58f5e0f 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -33,14 +33,17 @@ def test_allocate_naive(block_size: int, sequence_len: int): block_tables: list[BlockTable] = [] for i in range(5): - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc + assert ( + allocator.get_num_free_blocks(device=Device.GPU) + == num_gpu_blocks - i * num_blocks_per_alloc + ) block_tables.append( BlockTable( block_size=block_size, block_allocator=allocator, - )) + ) + ) block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) @@ -71,35 +74,33 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): token_ids = list(range(sequence_len)) chunked_tokens = list(chunk_list(token_ids, block_size)) - num_mutable_blocks_per_alloc = 0 if len( - chunked_tokens[-1]) == block_size else 1 - num_immutable_blocks_per_alloc = len( - chunked_tokens) - num_mutable_blocks_per_alloc + num_mutable_blocks_per_alloc = 0 if len(chunked_tokens[-1]) == block_size else 1 + num_immutable_blocks_per_alloc = len(chunked_tokens) - num_mutable_blocks_per_alloc block_tables: list[BlockTable] = [] for alloc_i in range(1, 6): - block_tables.append( BlockTable( block_size=block_size, block_allocator=allocator, - )) + ) + ) block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) # Expect all sequences to share allocations, except for their last block # (which may be mutable). - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - ( - num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc * - (alloc_i)) + assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks - ( + num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc * (alloc_i) + ) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("sequence_len", [1, 16, 129]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) @pytest.mark.parametrize("device", ["cpu", "gpu"]) -def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, - device: str): +def test_allocate_free( + block_size: int, sequence_len: int, allocator_type: str, device: str +): """Test the allocation and freeing of blocks using different allocators and devices. 
@@ -128,10 +129,11 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, for i in range(5): block_table.allocate(token_ids=token_ids, device=device) - assert allocator.get_num_free_blocks( - device) == num_device_blocks - num_blocks_per_alloc - assert all(block_id is not None - for block_id in block_table.physical_block_ids) + assert ( + allocator.get_num_free_blocks(device) + == num_device_blocks - num_blocks_per_alloc + ) + assert all(block_id is not None for block_id in block_table.physical_block_ids) block_table.free() assert allocator.get_num_free_blocks(device) == num_device_blocks @@ -141,8 +143,9 @@ def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, @pytest.mark.parametrize("sequence_len", [1, 16, 129]) @pytest.mark.parametrize("append_len", [1, 16, 129]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_allocation(block_size: int, sequence_len: int, - append_len: int, allocator_type: str): +def test_append_token_ids_allocation( + block_size: int, sequence_len: int, append_len: int, allocator_type: str +): """Test the allocation behavior when appending token IDs to a BlockTable. This test creates a CpuGpuBlockAllocator with the specified block size, @@ -169,29 +172,29 @@ def test_append_token_ids_allocation(block_size: int, sequence_len: int, block_allocator=allocator, ) - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + token_ids_to_append, - block_size))) - num_expected_blocks_before_append + num_expected_blocks_before_append = len(list(chunk_list(token_ids, block_size))) + num_expected_appended_blocks = ( + len(list(chunk_list(token_ids + token_ids_to_append, block_size))) + - num_expected_blocks_before_append + ) block_table.allocate(token_ids=token_ids, device=Device.GPU) - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append + assert len(block_table.physical_block_ids) == num_expected_blocks_before_append block_table.append_token_ids(token_ids_to_append) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks + assert ( + len(block_table.physical_block_ids) + == num_expected_blocks_before_append + num_expected_appended_blocks + ) @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("sequence_len", [1, 16, 129]) @pytest.mark.parametrize("num_empty_slots", [1, 16, 129]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, - num_empty_slots: int, - allocator_type: str): +def test_ensure_num_empty_slots_allocation( + block_size: int, sequence_len: int, num_empty_slots: int, allocator_type: str +): """Test the allocation behavior when ensuring a certain number of empty slots in a BlockTable. 
@@ -218,22 +221,22 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, block_allocator=allocator, ) - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + [-1] * num_empty_slots, - block_size))) - num_expected_blocks_before_append + num_expected_blocks_before_append = len(list(chunk_list(token_ids, block_size))) + num_expected_appended_blocks = ( + len(list(chunk_list(token_ids + [-1] * num_empty_slots, block_size))) + - num_expected_blocks_before_append + ) block_table.allocate(token_ids=token_ids, device=Device.GPU) # Assert that the empty slots consume the expected number of additional # blocks. - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append + assert len(block_table.physical_block_ids) == num_expected_blocks_before_append block_table.ensure_num_empty_slots(num_empty_slots) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks + assert ( + len(block_table.physical_block_ids) + == num_expected_blocks_before_append + num_expected_appended_blocks + ) # Now, ensure no additional blocks consumed as we fill up the empty slots. num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) @@ -246,9 +249,13 @@ def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, @pytest.mark.parametrize("append_len", [1, 16, 129]) @pytest.mark.parametrize("append_size", [1, 4, 129]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_correct_content(block_size: int, sequence_len: int, - append_len: int, allocator_type: str, - append_size: int): +def test_append_token_ids_correct_content( + block_size: int, + sequence_len: int, + append_len: int, + allocator_type: str, + append_size: int, +): """Verify token ids are correctly appended. Appends various amounts of token ids in various append sizes, and verifies the final sequence is correct. @@ -286,13 +293,13 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) def test_fork(seq_len: int, block_size: int, allocator_type: str): """Create a sequence using the specified allocator. - 1. Assert that after forking the sequence, the free block count is the - same. - 2. Assert that the forked sequence has the same physical mappings. - 3. Then free the original sequence; verify that the free block count is - the same. - 4. Finally, free the forked sequence and verify that the free block - count drops to zero. + 1. Assert that after forking the sequence, the free block count is the + same. + 2. Assert that the forked sequence has the same physical mappings. + 3. Then free the original sequence; verify that the free block count is + the same. + 4. Finally, free the forked sequence and verify that the free block + count drops to zero. """ num_gpu_blocks = 1024 @@ -312,30 +319,30 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): block_table.allocate(token_ids) - num_free_blocks_before_fork = allocator.get_num_free_blocks( - device=Device.GPU) + num_free_blocks_before_fork = allocator.get_num_free_blocks(device=Device.GPU) forked_block_table = block_table.fork() # Expect physical_block_ids and token_ids to match. 
- assert (block_table.physical_block_ids == - forked_block_table.physical_block_ids) - assert block_table._get_all_token_ids( - ) == forked_block_table._get_all_token_ids() + assert block_table.physical_block_ids == forked_block_table.physical_block_ids + assert block_table._get_all_token_ids() == forked_block_table._get_all_token_ids() # Do not expect any additional allocations. - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork + assert ( + allocator.get_num_free_blocks(device=Device.GPU) == num_free_blocks_before_fork + ) # Free the original blocks. Assert num free blocks does not change, since # refcount is nonzero. block_table.free() - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork + assert ( + allocator.get_num_free_blocks(device=Device.GPU) == num_free_blocks_before_fork + ) # Expect the forked block table to be unaffected by the free. - assert all(block_id is not None - for block_id in forked_block_table.physical_block_ids) + assert all( + block_id is not None for block_id in forked_block_table.physical_block_ids + ) # Free the forked blocks. Assert num free blocks does change, since # refcount is now zero. @@ -348,10 +355,14 @@ def test_fork(seq_len: int, block_size: int, allocator_type: str): @pytest.mark.parametrize("append_len", [1, 16, 129]) @pytest.mark.parametrize("appender", ["forked", "original"]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow(block_size: int, sequence_len: int, append_len: int, - allocator_type: str, appender: str): - """Fork a sequence; append to the forked sequence; verify there's a CoW. - """ +def test_cow( + block_size: int, + sequence_len: int, + append_len: int, + allocator_type: str, + appender: str, +): + """Fork a sequence; append to the forked sequence; verify there's a CoW.""" num_gpu_blocks = 1024 allocator = CpuGpuBlockAllocator.create( @@ -370,8 +381,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, ) num_expected_non_cow_blocks = cdiv(sequence_len, block_size) - num_expected_cow_blocks = cdiv(sequence_len + append_len, - block_size) - (sequence_len // block_size) + num_expected_cow_blocks = cdiv(sequence_len + append_len, block_size) - ( + sequence_len // block_size + ) original_block_table.allocate(token_ids=token_ids, device=Device.GPU) original_block_ids = original_block_table.physical_block_ids[:] @@ -380,8 +392,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, forked_block_table = original_block_table.fork() # Expect no additional allocation (copy on _write_). - assert allocator.get_num_free_blocks( - Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks) + assert allocator.get_num_free_blocks(Device.GPU) == ( + num_gpu_blocks - num_expected_non_cow_blocks + ) if appender == "forked": appender_block_table = forked_block_table @@ -400,9 +413,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, assert appender_block_table.physical_block_ids != original_block_ids # Expect the blocks changed during append to have a CoW. 
- assert allocator.get_num_free_blocks( - Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks + - num_expected_cow_blocks) + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - ( + num_expected_non_cow_blocks + num_expected_cow_blocks + ) cows = allocator.clear_copy_on_writes() if sequence_len % block_size > 0: @@ -432,9 +445,14 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, @pytest.mark.parametrize("lookahead_slots", [1, 16, 129]) @pytest.mark.parametrize("appender", ["forked", "original"]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow_lookahead_simple(block_size: int, sequence_len: int, - append_len: int, lookahead_slots: int, - allocator_type: str, appender: str): +def test_cow_lookahead_simple( + block_size: int, + sequence_len: int, + append_len: int, + lookahead_slots: int, + allocator_type: str, + appender: str, +): """Similar to test_cow, except with lookahead allocation. The assertions are less rigorous due to the complexity of the property under test. """ @@ -507,10 +525,13 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, @pytest.mark.parametrize("num_new_tokens", [1, 16, 129]) @pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, - num_new_tokens: int, - num_lookahead_slots: int, - allocator_type: str): +def test_num_blocks_touched_by_append_slots( + block_size: int, + sequence_len: int, + num_new_tokens: int, + num_lookahead_slots: int, + allocator_type: str, +): """Verify correct calculation of get_num_blocks_touched_by_append_slots. This is done by using copy-on-write, which requires any modified block to @@ -547,10 +568,9 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, _ = block_table.fork() # Determine how many blocks should be touched. - expected_num_touched_blocks = ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=token_ids_to_append, - num_lookahead_slots=num_lookahead_slots)) + expected_num_touched_blocks = block_table.get_num_blocks_touched_by_append_slots( + token_ids=token_ids_to_append, num_lookahead_slots=num_lookahead_slots + ) # Measure how many blocks are touched by measuring num_free_blocks before # and after the append. @@ -558,8 +578,9 @@ def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, # We expect append_token_ids to CoW all mutated blocks that have refcount>1. num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) - num_consumed_blocks = (num_free_blocks_before_append - - allocator.get_num_free_blocks(Device.GPU)) + num_consumed_blocks = num_free_blocks_before_append - allocator.get_num_free_blocks( + Device.GPU + ) # TODO(cade) ensure equality when num_lookahead_slots > 0. 
# The reason we have < is because lookahead blocks are not copied eagerly; diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py index 795eef6743fd..1b2151fcf2d2 100644 --- a/tests/core/block/test_cpu_gpu_block_allocator.py +++ b/tests/core/block/test_cpu_gpu_block_allocator.py @@ -11,8 +11,9 @@ @pytest.mark.parametrize("num_gpu_blocks", [1024]) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): +def test_allocate_mutable_block( + num_cpu_blocks: int, num_gpu_blocks: int, block_size: int, allocator_type: str +): allocator = CpuGpuBlockAllocator.create( allocator_type=allocator_type, num_gpu_blocks=num_gpu_blocks, @@ -50,8 +51,9 @@ def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, @pytest.mark.parametrize("num_gpu_blocks", [1024]) @pytest.mark.parametrize("block_size", [2]) @pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): +def test_allocate_immutable_block( + num_cpu_blocks: int, num_gpu_blocks: int, block_size: int, allocator_type: str +): allocator = CpuGpuBlockAllocator.create( allocator_type=allocator_type, num_gpu_blocks=num_gpu_blocks, @@ -59,29 +61,30 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, block_size=block_size, ) - unique_token_ids = list( - range((num_cpu_blocks + num_gpu_blocks) * block_size)) + unique_token_ids = list(range((num_cpu_blocks + num_gpu_blocks) * block_size)) gpu_token_ids = list( - chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) + chunk_list(unique_token_ids[: num_gpu_blocks * block_size], block_size) + ) cpu_token_ids = list( - chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size)) + chunk_list(unique_token_ids[num_gpu_blocks * block_size :], block_size) + ) assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks cpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.CPU) + allocator.allocate_immutable_block( + prev_block=None, token_ids=token_ids, device=Device.CPU + ) for token_ids in cpu_token_ids ] assert allocator.get_num_free_blocks(Device.CPU) == 0 assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks gpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.GPU) + allocator.allocate_immutable_block( + prev_block=None, token_ids=token_ids, device=Device.GPU + ) for token_ids in gpu_token_ids ] assert allocator.get_num_free_blocks(Device.CPU) == 0 diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index a31d1c46b37f..1e2e104c6113 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -10,18 +10,21 @@ class TestNaiveBlockAllocator: - @staticmethod - def create_allocate_lambda(allocate_type: str, - allocator: NaiveBlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): + def create_allocate_lambda( + allocate_type: str, + allocator: NaiveBlockAllocator, + prev_block: Optional[Block], + token_ids: list[int], + ): if allocate_type == "immutable": allocate_block = lambda: 
allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) + prev_block=prev_block, token_ids=token_ids + ) elif allocate_type == "mutable": allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) + prev_block=prev_block + ) else: raise ValueError() @@ -31,16 +34,13 @@ def create_allocate_lambda(allocate_type: str, @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_ooms(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) + def test_allocate_ooms(allocate_type: str, num_blocks: int, block_size: int): + allocator = NaiveBlockAllocator( + create_block=NaiveBlock, num_blocks=num_blocks, block_size=block_size + ) allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) + allocate_type, allocator, prev_block=None, token_ids=list(range(block_size)) + ) [allocate_block() for _ in range(num_blocks)] with pytest.raises(BlockAllocator.NoFreeBlocksError): @@ -50,16 +50,13 @@ def test_allocate_ooms(allocate_type: str, num_blocks: int, @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) - def test_free_prevents_oom(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) + def test_free_prevents_oom(allocate_type: str, num_blocks: int, block_size: int): + allocator = NaiveBlockAllocator( + create_block=NaiveBlock, num_blocks=num_blocks, block_size=block_size + ) allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) + allocate_type, allocator, prev_block=None, token_ids=list(range(block_size)) + ) blocks = [allocate_block() for _ in range(num_blocks)] @@ -85,16 +82,13 @@ def test_free_prevents_oom(allocate_type: str, num_blocks: int, @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) @pytest.mark.parametrize("num_blocks", [1024]) @pytest.mark.parametrize("block_size", [16]) - def test_get_num_free_blocks(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) + def test_get_num_free_blocks(allocate_type: str, num_blocks: int, block_size: int): + allocator = NaiveBlockAllocator( + create_block=NaiveBlock, num_blocks=num_blocks, block_size=block_size + ) allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) + allocate_type, allocator, prev_block=None, token_ids=list(range(block_size)) + ) assert allocator.get_num_free_blocks() == num_blocks @@ -108,41 +102,37 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int, @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): - """ Verify the allocator can correctly return the number of + """Verify the allocator can correctly return the number of full blocks touched. 
""" - allocator_src = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) + allocator_src = NaiveBlockAllocator( + create_block=NaiveBlock, num_blocks=num_blocks, block_size=block_size + ) + allocator_dst = NaiveBlockAllocator( + create_block=NaiveBlock, num_blocks=num_blocks, block_size=block_size + ) # Create a chain of cacheable blocks in the dst allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( "immutable", allocator_src, prev_block=None, - token_ids=list(range(block_size))) + token_ids=list(range(block_size)), + ) src_blocks = [allocate_block() for _ in range(num_blocks - 1)] # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 + assert allocator_dst.get_num_full_blocks_touched(src_blocks) == num_blocks - 1 # Insert one non-full block in the src - allocate_non_full_block = \ - TestNaiveBlockAllocator.create_allocate_lambda( - "mutable", allocator_src, - prev_block=src_blocks[-1],token_ids=[] - ) + allocate_non_full_block = TestNaiveBlockAllocator.create_allocate_lambda( + "mutable", allocator_src, prev_block=src_blocks[-1], token_ids=[] + ) src_blocks.append(allocate_non_full_block()) src_blocks[-1].append_token_ids([0]) - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 + assert allocator_dst.get_num_full_blocks_touched(src_blocks) == num_blocks - 1 # Fill up the last source block and then invoke # get_num_blocks_touched src_blocks[-1].append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks + assert allocator_dst.get_num_full_blocks_touched(src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 46e224c6f53b..6236eddf33b4 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -11,33 +11,37 @@ from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - PrefixCachingBlock, - PrefixCachingBlockAllocator) +from vllm.core.block.prefix_caching_block import ( + ComputedBlocksTracker, + PrefixCachingBlock, + PrefixCachingBlockAllocator, +) from vllm.sequence import Logprob from vllm.utils import Device class TestPrefixCachingBlock: - @staticmethod @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("block_size", [1, 16]) @pytest.mark.parametrize("is_curr_block_full", [True, False]) - def test_first_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool): - """Verify a block which is first in the sequence has the correct hash. 
- """ + def test_first_block_has_correct_content_hash( + seed: int, block_size: int, is_curr_block_full: bool + ): + """Verify a block which is first in the sequence has the correct hash.""" random.seed(seed) - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) + num_to_fill = ( + block_size if is_curr_block_full else random.randint(0, block_size - 1) + ) token_ids = list(range(num_to_fill)) mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - block_with_prev = PrefixCachingBlock(prev_block=None, - token_ids=token_ids, - block_size=block_size, - allocator=mock_allocator) + block_with_prev = PrefixCachingBlock( + prev_block=None, + token_ids=token_ids, + block_size=block_size, + allocator=mock_allocator, + ) if is_curr_block_full: # Expect hash since block is full. @@ -45,7 +49,9 @@ def test_first_block_has_correct_content_hash(seed: int, block_size: int, PrefixCachingBlock.hash_block_tokens( is_first_block=True, prev_block_hash=None, - cur_block_token_ids=token_ids)) + cur_block_token_ids=token_ids, + ) + ) else: # Do not expect hash since block is not full. assert block_with_prev.content_hash is None @@ -55,9 +61,9 @@ def test_first_block_has_correct_content_hash(seed: int, block_size: int, @pytest.mark.parametrize("block_size", [1, 16]) @pytest.mark.parametrize("is_curr_block_full", [True, False]) @pytest.mark.parametrize("prev_block_has_hash", [True, False]) - def test_nth_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool, - prev_block_has_hash: bool): + def test_nth_block_has_correct_content_hash( + seed: int, block_size: int, is_curr_block_full: bool, prev_block_has_hash: bool + ): """Verify a block which is not first in the sequence has the correct hash. """ @@ -66,11 +72,13 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int, previous_block = MagicMock(spec=PrefixCachingBlock) prev_block_hash = random.randint(0, 1000) - previous_block.content_hash = (prev_block_hash if prev_block_has_hash - else hash('None')) + previous_block.content_hash = ( + prev_block_hash if prev_block_has_hash else hash("None") + ) - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) + num_to_fill = ( + block_size if is_curr_block_full else random.randint(0, block_size - 1) + ) token_ids = list(range(num_to_fill)) mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) @@ -83,11 +91,11 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int, if is_curr_block_full and prev_block_has_hash: # Expect hash since block is full and previous block has hash. - assert (block_with_prev.content_hash == - PrefixCachingBlock.hash_block_tokens( - is_first_block=False, - prev_block_hash=prev_block_hash, - cur_block_token_ids=token_ids)) + assert block_with_prev.content_hash == PrefixCachingBlock.hash_block_tokens( + is_first_block=False, + prev_block_hash=prev_block_hash, + cur_block_token_ids=token_ids, + ) else: # Do not expect hash since block is not full or the previous block # does not have a hash. 
@@ -97,9 +105,9 @@ def test_nth_block_has_correct_content_hash(seed: int, block_size: int, @pytest.mark.parametrize("block_size", [1, 2, 16]) @pytest.mark.parametrize("num_tokens", list(range(3))) @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10]) - def test_blocks_have_correct_hash_in_chain(block_size: int, - num_tokens: int, - num_empty_trailing_blocks: int): + def test_blocks_have_correct_hash_in_chain( + block_size: int, num_tokens: int, num_empty_trailing_blocks: int + ): """Create two chains of logical blocks with the same contents. Assert the hashes are equal. """ @@ -107,30 +115,29 @@ def test_blocks_have_correct_hash_in_chain(block_size: int, token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)] - first_chain, second_chain = (TestPrefixCachingBlock.create_chain( - block_size=block_size, - token_ids=token_ids, - num_empty_trailing_blocks=num_empty_trailing_blocks) - for _ in range(2)) + first_chain, second_chain = ( + TestPrefixCachingBlock.create_chain( + block_size=block_size, + token_ids=token_ids, + num_empty_trailing_blocks=num_empty_trailing_blocks, + ) + for _ in range(2) + ) - for first_chain_block, second_chain_block in zip( - first_chain, second_chain): - assert (first_chain_block.content_hash == - second_chain_block.content_hash) + for first_chain_block, second_chain_block in zip(first_chain, second_chain): + assert first_chain_block.content_hash == second_chain_block.content_hash if not first_chain or not second_chain: assert first_chain == second_chain assert num_tokens == 0 @staticmethod - def create_chain(block_size: int, - token_ids: list[int], - num_empty_trailing_blocks=0) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. - """ + def create_chain( + block_size: int, token_ids: list[int], num_empty_trailing_blocks=0 + ) -> list[PrefixCachingBlock]: + """Helper method which creates a chain of blocks.""" blocks: list[PrefixCachingBlock] = [] - num_blocks = math.ceil( - len(token_ids) / block_size) + num_empty_trailing_blocks + num_blocks = math.ceil(len(token_ids) / block_size) + num_empty_trailing_blocks if num_blocks == 0: return [] @@ -146,9 +153,9 @@ def create_chain(block_size: int, allocator=allocator, ) - tokens_to_append = token_ids[block_number * - block_size:(block_number + 1) * - block_size] + tokens_to_append = token_ids[ + block_number * block_size : (block_number + 1) * block_size + ] if tokens_to_append: prev_block.append_token_ids(tokens_to_append) @@ -158,17 +165,21 @@ def create_chain(block_size: int, class TestPrefixCachingBlockAllocator: - @staticmethod - def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): + def create_allocate_lambda( + allocate_type: str, + allocator: BlockAllocator, + prev_block: Optional[Block], + token_ids: list[int], + ): if allocate_type == "immutable": allocate_block = lambda: allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) + prev_block=prev_block, token_ids=token_ids + ) elif allocate_type == "mutable": allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) + prev_block=prev_block + ) else: raise ValueError() @@ -178,8 +189,9 @@ def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) def test_allocate_mutable_ooms(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, 
- block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( allocate_type="mutable", allocator=allocator, @@ -195,9 +207,11 @@ def test_allocate_mutable_ooms(num_blocks: int, block_size: int): @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) def test_allocate_immutable_does_not_oom_single_hash( - num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + num_blocks: int, block_size: int + ): + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( allocate_type="immutable", allocator=allocator, @@ -212,20 +226,20 @@ def test_allocate_immutable_does_not_oom_single_hash( # Expect all blocks to have same physical block index. for block in blocks: - assert (block.block_id == non_oom_block.block_id) + assert block.block_id == non_oom_block.block_id @staticmethod @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_immutable_ooms_many_hash(num_blocks: int, - block_size: int): + def test_allocate_immutable_ooms_many_hash(num_blocks: int, block_size: int): """Consume all blocks using many different hashes/block content. Do this by creating a sequence that is very long. Expect next block to OOM. """ - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) # Create token ids that will exhaust all blocks. token_ids = list(range(num_blocks * block_size)) @@ -238,9 +252,9 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int, # Expect allocation with unseen hash to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_immutable_block(prev_block=chain[-1], - token_ids=list( - range(block_size))) + allocator.allocate_immutable_block( + prev_block=chain[-1], token_ids=list(range(block_size)) + ) # Expect mutable allocation to fail. with pytest.raises(BlockAllocator.NoFreeBlocksError): @@ -256,14 +270,15 @@ def test_allocate_immutable_ooms_many_hash(num_blocks: int, # Expect physical block indices to be the same in both chains. assert chain and second_chain for first_chain_block, second_chain_block in zip(chain, second_chain): - assert (first_chain_block.block_id == second_chain_block.block_id) + assert first_chain_block.block_id == second_chain_block.block_id @staticmethod @pytest.mark.parametrize("num_blocks", [1, 1024]) @pytest.mark.parametrize("block_size", [1, 16]) def test_free_prevents_oom(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) # Create token ids that will exhaust all blocks. 
token_ids = list(range(num_blocks * block_size)) @@ -300,8 +315,9 @@ def test_free_prevents_oom(num_blocks: int, block_size: int): @pytest.mark.parametrize("seed", list(range(20))) def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. @@ -316,23 +332,24 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): # Free each block in chain, assert num free blocks includes new free # block. for i, block in enumerate(chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) + assert allocator.get_num_free_blocks() == ( + num_blocks - num_blocks_to_consume + i + ) allocator.free(block) @staticmethod @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) - def test_prefix_caching_block_get_num_full_blocks_touched( - num_blocks, block_size): - """ Verify the allocator can correctly return the number of + def test_prefix_caching_block_get_num_full_blocks_touched(num_blocks, block_size): + """Verify the allocator can correctly return the number of blocks touched, when there are cached prefixes. """ - allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocator_dst = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator_src = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) + allocator_dst = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) # Create token ids that will exhaust all blocks except the last token_ids = list(range((num_blocks - 1) * block_size)) @@ -345,49 +362,43 @@ def test_prefix_caching_block_get_num_full_blocks_touched( ) # Create a chain of the same blocks in the src - blocks_to_swap_in = \ - TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator_src, - ) + blocks_to_swap_in = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator_src, + ) # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 0 + assert allocator_dst.get_num_full_blocks_touched(blocks_to_swap_in) == 0 # Free the first block in the dst allocator_dst.free(cached_blocks[0]) # Now the first block becomes dangling, the swapped blocks need # to reclaim the first block in the dst - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 + assert allocator_dst.get_num_full_blocks_touched(blocks_to_swap_in) == 1 # Insert one non-full block in the src - non_full_block = allocator_src.allocate_mutable_block( - blocks_to_swap_in[-1]) + non_full_block = allocator_src.allocate_mutable_block(blocks_to_swap_in[-1]) non_full_block.append_token_ids([0]) blocks_to_swap_in.append(non_full_block) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 + assert allocator_dst.get_num_full_blocks_touched(blocks_to_swap_in) == 1 # Fill up the last mutable block and invoke get_num_blocks_touched. # Note: The last block is not cached so it will be touched. 
non_full_block.append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 2 + assert allocator_dst.get_num_full_blocks_touched(blocks_to_swap_in) == 2 @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("seed", list(range(20))) - def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, - seed: int): + def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, seed: int): """Verify sharing occurs by allocating two sequences that share prefixes and incrementally freeing blocks. """ random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. @@ -407,32 +418,33 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, # Free each block in the first chain. Since all blocks are shared, the # free count should stay constant. for i, block in enumerate(first_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume) + assert allocator.get_num_free_blocks() == ( + num_blocks - num_blocks_to_consume + ) allocator.free(block) # Free each block in the second chain. Since the refcount is now zero, # the free count should increment with each free. for i, block in enumerate(second_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) + assert allocator.get_num_free_blocks() == ( + num_blocks - num_blocks_to_consume + i + ) allocator.free(block) @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("seed", list(range(20))) - def test_get_common_computed_block_ids(num_blocks: int, block_size: int, - seed: int): + def test_get_common_computed_block_ids(num_blocks: int, block_size: int, seed: int): """Verify get_common_computed_block_ids could get correct result by create two immutable chain sharing prefix at specified pos, and compare whether we also could get right result from get_common_computed_block_ids. """ random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks * 2, block_size=block_size + ) num_blocks_to_consume = random.randint(1, num_blocks - 1) # Create token ids that will exhaust all blocks. 
@@ -463,9 +475,10 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, second_chain[i].block_id for i in range(num_blocks_to_consume) ] res = allocator.get_common_computed_block_ids( - [first_computed_ids, second_computed_ids]) + [first_computed_ids, second_computed_ids] + ) - assert (len(res) == zero_point_blocks) + assert len(res) == zero_point_blocks # Test case that assume those prompted block after first immutable would # be freed into hashless allocator, while first immutable block get ref @@ -477,12 +490,12 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) token_ids = list(range(block_size)) - block = allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) + block = allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) assert allocator._refcounter.get(block.block_id) == 1 m = allocator.allocate_mutable_block(prev_block=None) @@ -511,15 +524,17 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): all_blocks_list = [i for i in range(num_blocks)] zero_ref = {i: 0 for i in range(num_blocks)} one_ref = {i: 1 for i in range(num_blocks)} - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) token_ids = list(range(num_blocks * block_size)) # Verify initial/pre-alloc state # Ensure all blocks are free inside hashless allocator - assert list(allocator._hashless_allocator._free_block_indices - ) == all_blocks_list + assert ( + list(allocator._hashless_allocator._free_block_indices) == all_blocks_list + ) # Ensure no tracked blocks assert len(allocator._block_tracker.keys()) == num_blocks for block_id in range(num_blocks): @@ -536,13 +551,14 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): for i in range(num_blocks): block = allocator.allocate_immutable_block( prev_block=None, - token_ids=token_ids[block_size * i:block_size * (i + 1)]) + token_ids=token_ids[block_size * i : block_size * (i + 1)], + ) new_block.append(block) # Verify post-alloc state # Ensure no blocks are free inside hashless allocator - assert (len(allocator._hashless_allocator._free_block_indices) == 0) + assert len(allocator._hashless_allocator._free_block_indices) == 0 # Ensure all blocks are tracked assert len(allocator._block_tracker.keys()) == num_blocks for block_id in range(num_blocks): @@ -601,7 +617,8 @@ def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): # shall get free block from hashless allocator, thus no block left # in hashless block = allocator.allocate_immutable_block( - prev_block=None, token_ids=token_ids[:block_size]) + prev_block=None, token_ids=token_ids[:block_size] + ) assert block.block_id == 0 assert len(allocator._hashless_allocator._free_block_indices) == 0 @@ -632,8 +649,9 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): """ random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) num_blocks_to_consume = num_blocks + 1 token_ids = list(range(num_blocks_to_consume * 
block_size)) @@ -647,8 +665,9 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): allocator=allocator, ) # There should only be one block allocated at this point - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_in_first_chain) + assert allocator.get_num_free_blocks() == ( + num_blocks - num_blocks_in_first_chain + ) # Set the last accessed time of the first block to 1 blocks_ids = [block.block_id for block in first_chain] @@ -693,26 +712,22 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): @staticmethod def test_metric(): block_size = 16 - allocator = PrefixCachingBlockAllocator(num_blocks=4, - block_size=block_size) + allocator = PrefixCachingBlockAllocator(num_blocks=4, block_size=block_size) # Test when no query (0/0) assert allocator.get_prefix_cache_hit_rate() == 0.0 token_ids = list(range(block_size)) - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) + allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) # Test 0/1 hit rate assert allocator.get_prefix_cache_hit_rate() == 0.0 - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) + allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) # Test 1/2 hit rate assert allocator.get_prefix_cache_hit_rate() == 0.5 # Test more than one block for _ in range(2, 1005): - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) + allocator.allocate_immutable_block(prev_block=None, token_ids=token_ids) assert allocator.get_prefix_cache_hit_rate() > 0.99 # Test case for marking cache hit blocks as computed right after @@ -721,8 +736,7 @@ def test_metric(): def test_touch_block(): block_size = 16 common_blocks = 4 - allocator = PrefixCachingBlockAllocator(num_blocks=8, - block_size=block_size) + allocator = PrefixCachingBlockAllocator(num_blocks=8, block_size=block_size) common_token_ids = list(range(block_size * common_blocks)) @@ -737,13 +751,13 @@ def test_touch_block(): block_hashes = [block.content_hash for block in blocks] # The allocated blocks should be marked as touched # but not computed. - computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes) + computed_block_ids = allocator.find_cached_blocks_prefix(block_hashes) assert len(computed_block_ids) == 0 allocator.mark_blocks_as_computed([]) computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes) + block_hashes=block_hashes + ) assert len(computed_block_ids) == common_blocks @staticmethod @@ -754,11 +768,12 @@ def test_find_cached_blocks_prefix(): block_size = 4 num_blocks = 8 total_test_blocks = 12 - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) token_ids = list(range(total_test_blocks * block_size)) - block_tokens_seq1 = token_ids[:num_blocks * block_size] + block_tokens_seq1 = token_ids[: num_blocks * block_size] blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, token_ids=block_tokens_seq1, @@ -769,7 +784,8 @@ def test_find_cached_blocks_prefix(): # All blocks should be cached. cached_blocks_seq1 = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) + block_hashes=block_hashes_seq1 + ) assert len(cached_blocks_seq1) == num_blocks # Free the first sequence. 
@@ -778,10 +794,11 @@ def test_find_cached_blocks_prefix(): # All blocks should be still be cached if not required to be allocated. cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) + block_hashes=block_hashes_seq1 + ) assert len(cached_blocks) == num_blocks - block_tokens_seq2 = token_ids[num_blocks * block_size:] + block_tokens_seq2 = token_ids[num_blocks * block_size :] blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, token_ids=block_tokens_seq2, @@ -790,13 +807,15 @@ def test_find_cached_blocks_prefix(): block_hashes_seq2 = [block.content_hash for block in blocks_seq2] allocator.mark_blocks_as_computed([]) cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq2) + block_hashes=block_hashes_seq2 + ) assert len(cached_blocks) == len(blocks_seq2) # Half of the blocks from seq1 should still be cached. num_evicted_blocks = len(blocks_seq2) cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) + block_hashes=block_hashes_seq1 + ) assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks # Test reset prefix cache @@ -806,8 +825,9 @@ def test_find_cached_blocks_prefix(): def test_reset_prefix_cache(num_blocks: int, block_size: int): """This test case simulates the case of resetting the prefix cache.""" - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) + allocator = PrefixCachingBlockAllocator( + num_blocks=num_blocks, block_size=block_size + ) token_ids = list(range(3 * block_size)) first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( @@ -844,8 +864,7 @@ def create_immutable_chain( allocator: PrefixCachingBlockAllocator, extra_hash: Optional[int] = None, ) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. - """ + """Helper method which creates a chain of blocks.""" blocks: list[Block] = [] num_blocks = math.ceil(len(token_ids) / block_size) @@ -854,20 +873,18 @@ def create_immutable_chain( prev_block = None for block_number in range(0, num_blocks): - block_token_ids = token_ids[block_number * - block_size:(block_number + 1) * - block_size] + block_token_ids = token_ids[ + block_number * block_size : (block_number + 1) * block_size + ] prev_block = allocator.allocate_immutable_block( - prev_block=prev_block, - token_ids=block_token_ids, - extra_hash=extra_hash) + prev_block=prev_block, token_ids=block_token_ids, extra_hash=extra_hash + ) blocks.append(prev_block) return blocks class TestComputedBlocksTracker: - @staticmethod def _get_mock_allocator(): return MagicMock(spec=PrefixCachingBlockAllocator) @@ -898,9 +915,9 @@ def test_get_num_cached_tokens(): # Not yet allocated. tokens = [0, 1, 2, 3, 4, 5] - seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) + seq1 = create_dummy_sequence( + request_id=0, token_ids=tokens, block_size=block_size + ) mock_allocator.find_cached_blocks_prefix.return_value = [] assert tracker.get_num_cached_tokens(seq1) == 0 @@ -934,11 +951,10 @@ def test_get_num_cached_tokens(): tracker.remove_seq(seq1.seq_id) # Re-create the sequence with the same request id to simulate recompute. 
- seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - mock_allocator.find_cached_blocks_prefix.return_value = [ - ] # no cached block + seq1 = create_dummy_sequence( + request_id=0, token_ids=tokens, block_size=block_size + ) + mock_allocator.find_cached_blocks_prefix.return_value = [] # no cached block assert tracker.get_num_cached_tokens(seq1) == 0 @staticmethod @@ -964,9 +980,9 @@ def test_correct_block_hash(): ) tokens = list(range(block_size * 4)) # 4 blocks. - seq = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) + seq = create_dummy_sequence( + request_id=0, token_ids=tokens, block_size=block_size + ) _ = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, token_ids=tokens, @@ -1001,10 +1017,9 @@ def test_correct_extra_hash(): tokens = list(range(block_size * 4)) # Create a dummy LoRA sequence with a specific LoRA ID. - lora_seq = create_dummy_lora_sequence(request_id=0, - token_ids=tokens, - block_size=block_size, - lora_int_id=1) + lora_seq = create_dummy_lora_sequence( + request_id=0, token_ids=tokens, block_size=block_size, lora_int_id=1 + ) _ = TestPrefixCachingBlockAllocator.create_immutable_chain( block_size=block_size, @@ -1017,14 +1032,13 @@ def test_correct_extra_hash(): # Create different dummy sequences that have the same token IDs # but different LoRA IDs. - seq = create_dummy_sequence(request_id=1, - token_ids=tokens, - block_size=block_size) - - different_lora_seq = create_dummy_lora_sequence(request_id=2, - token_ids=tokens, - block_size=block_size, - lora_int_id=2) + seq = create_dummy_sequence( + request_id=1, token_ids=tokens, block_size=block_size + ) + + different_lora_seq = create_dummy_lora_sequence( + request_id=2, token_ids=tokens, block_size=block_size, lora_int_id=2 + ) # Due to the different LoRA IDs, corresponding blocks are not cached. assert tracker.get_num_cached_tokens(seq) == 0 diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 375b248ebeda..a6a8b33e19d3 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -9,4 +9,4 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index d4dacc4f1296..a5fb2f966248 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -37,11 +37,13 @@ def test_simple(): num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -50,9 +52,9 @@ def test_simple(): # Add seq groups to scheduler. 
for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=block_size, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -61,8 +63,11 @@ def test_simple(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == num_seq_group for s in running: append_new_token(s, 1) @@ -71,8 +76,11 @@ def test_simple(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == num_seq_group @@ -97,9 +105,9 @@ def test_chunk(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -127,7 +135,7 @@ def test_chunk(): def test_concurrent_chunking(): - """Verify prefills are chunked properly when + """Verify prefills are chunked properly when --max-num-partial-prefills is > 1""" block_size = 4 max_seqs = 60 @@ -149,9 +157,9 @@ def test_concurrent_chunking(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -196,7 +204,8 @@ def test_concurrent_chunking_large_requests(): _, seq_group = create_dummy_prompt( str(i), prompt_length=1200, # Very large prompt - block_size=block_size) + block_size=block_size, + ) scheduler.add_seq_group(seq_group) # Verify only a single request is chunked, and it gets all 64 tokens @@ -208,7 +217,7 @@ def test_concurrent_chunking_large_requests(): def test_short_prompts_jump_long_prompts_in_queue(): - """Verify large prefill requests are punted behind smaller ones if + """Verify large prefill requests are punted behind smaller ones if another large prefill request is already running""" block_size = 4 max_seqs = 60 @@ -234,7 +243,8 @@ def test_short_prompts_jump_long_prompts_in_queue(): _, seq_group = create_dummy_prompt( str(i), prompt_length=1200, # Very large prompt - block_size=block_size) + block_size=block_size, + ) scheduler.add_seq_group(seq_group) long_seqs.append(seq_group) assert seq_group.is_prefill() @@ -244,7 +254,8 @@ def test_short_prompts_jump_long_prompts_in_queue(): _, seq_group = create_dummy_prompt( str(i + 2), prompt_length=40, # Very small prompt - block_size=block_size) + block_size=block_size, + ) scheduler.add_seq_group(seq_group) short_seqs.append(seq_group) assert seq_group.is_prefill() @@ -372,9 +383,9 @@ def test_complex(): # Add seq groups to scheduler. 
for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -395,9 +406,9 @@ def test_complex(): # Add 2 more requests. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -443,9 +454,9 @@ def test_maximal_decoding(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=2, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=2, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -462,9 +473,7 @@ def test_maximal_decoding(): append_new_token(running[0], 1) # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", - prompt_length=2, - block_size=block_size) + _, seq_group = create_dummy_prompt("3", prompt_length=2, block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -535,9 +544,7 @@ def test_prompt_limit(): scheduler = Scheduler(scheduler_config, cache_config, None) running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", - prompt_length=48, - block_size=block_size) + _, seq_group = create_dummy_prompt("1", prompt_length=48, block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -556,19 +563,19 @@ def test_prompt_limit_exceed(): max_seqs = 64 max_model_len = 32 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("2", - prompt_length=48, - block_size=block_size) + _, seq_group = create_dummy_prompt("2", prompt_length=48, block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -595,9 +602,7 @@ def test_chunked_prefill_preempt(): cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt("1", prompt_length=60, block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -613,8 +618,7 @@ def test_chunked_prefill_preempt(): def cannot_append_second_group1(seq_group, num_lookahead_slots): return seq_group.request_id != "1" - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group1) + scheduler.block_manager.can_append_slots.side_effect = cannot_append_second_group1 # The running prefill is now preempted. 
_, out = schedule_and_update_computed_tokens(scheduler) @@ -635,8 +639,7 @@ def cannot_append_second_group1(seq_group, num_lookahead_slots): def cannot_append_second_group2(seq_group, num_lookahead_slots): return True - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group2) + scheduler.block_manager.can_append_slots.side_effect = cannot_append_second_group2 _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 assert out.num_prefill_groups == 1 @@ -668,9 +671,7 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps): cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", - prompt_length=30, - block_size=block_size) + _, seq_group = create_dummy_prompt("1", prompt_length=30, block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -679,8 +680,9 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps): assert out.num_prefill_groups == 1 assert out.num_batched_tokens == max_num_batched_tokens print(out.num_lookahead_slots) - assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else - num_lookahead_slots) + assert out.num_lookahead_slots == ( + 0 if (num_scheduler_steps == 1) else num_lookahead_slots + ) def test_chunked_prefill_max_seqs(): @@ -701,9 +703,7 @@ def test_chunked_prefill_max_seqs(): scheduler = Scheduler(scheduler_config, cache_config, None) running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", - prompt_length=65, - block_size=block_size) + _, seq_group = create_dummy_prompt("1", prompt_length=65, block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # The first prefill is chunked. @@ -713,9 +713,9 @@ def test_chunked_prefill_max_seqs(): # Add new requests. for i in range(4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=65, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=65, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -750,11 +750,7 @@ def test_prefix_caching(): max_model_len, enable_chunked_prefill=True, ) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto", enable_prefix_caching=True) cache_config.num_cpu_blocks = 0 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) @@ -762,9 +758,9 @@ def test_prefix_caching(): # Add seq groups to scheduler. 
for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) + _, seq_group = create_dummy_prompt( + str(i), block_size=block_size, prompt_length=50 + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -780,23 +776,21 @@ def test_prefix_caching(): def test_prefix_caching_with_concurrent_partial_prefills(): - """Verify allocating full blocks when prefix caching is enabled with + """Verify allocating full blocks when prefix caching is enabled with --max-num-partial-prefills > 1.""" block_size = 4 max_seqs = 10 max_model_len = 8000 max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto", enable_prefix_caching=True) cache_config.num_cpu_blocks = 0 cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) @@ -804,9 +798,9 @@ def test_prefix_caching_with_concurrent_partial_prefills(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) + _, seq_group = create_dummy_prompt( + str(i), block_size=block_size, prompt_length=50 + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -833,9 +827,8 @@ def test_prefix_caching_with_concurrent_partial_prefills(): @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) -def test_chunked_prefill_with_actual_engine(model: str, - max_num_partial_prefills: int): - """Make sure the model can actually sample with concurrent +def test_chunked_prefill_with_actual_engine(model: str, max_num_partial_prefills: int): + """Make sure the model can actually sample with concurrent partial prefills """ diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 1b958e34df87..8616432905f7 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -20,29 +20,30 @@ def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): @pytest.mark.parametrize("num_scheduler_steps", [1, 8]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("enforce_eager", [False, True]) -def test_num_computed_tokens_update(num_scheduler_steps: int, - enable_chunked_prefill: bool, - enforce_eager: bool): - +def test_num_computed_tokens_update( + num_scheduler_steps: int, enable_chunked_prefill: bool, enforce_eager: bool +): is_multi_step = num_scheduler_steps > 1 is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill if is_multi_step_chunked_prefill and current_platform.is_rocm(): - pytest.skip("Multi-step with Chunked-Prefill does not support " - "rocm_flash_attn backend") + pytest.skip( + "Multi-step with Chunked-Prefill does not support rocm_flash_attn backend" + ) # Make a vllm engine - runner = VllmRunner(model_name=MODEL, - gpu_memory_utilization=0.7, - num_scheduler_steps=num_scheduler_steps, - enable_chunked_prefill=enable_chunked_prefill, - 
enforce_eager=enforce_eager) + runner = VllmRunner( + model_name=MODEL, + gpu_memory_utilization=0.7, + num_scheduler_steps=num_scheduler_steps, + enable_chunked_prefill=enable_chunked_prefill, + enforce_eager=enforce_eager, + ) engine: LLMEngine = runner.model.llm_engine # In multi-step + chunked-prefill there is no separate single prompt step. # What is scheduled will run for num_scheduler_steps always. - num_prompt_steps = num_scheduler_steps \ - if is_multi_step_chunked_prefill else 1 + num_prompt_steps = num_scheduler_steps if is_multi_step_chunked_prefill else 1 num_output_tokens_list = [4, 8, 12, 15, 16, 17] @@ -50,10 +51,12 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, prompt_len = 10 for req_idx, num_output_tokens in enumerate(num_output_tokens_list): - seq, seq_group = create_dummy_prompt(request_id=str(req_idx), - prompt_length=prompt_len, - min_tokens=num_output_tokens, - max_tokens=num_output_tokens) + seq, seq_group = create_dummy_prompt( + request_id=str(req_idx), + prompt_length=prompt_len, + min_tokens=num_output_tokens, + max_tokens=num_output_tokens, + ) add_seq_group_to_engine(engine, seq_group) assert seq.data.get_num_computed_tokens() == 0 @@ -65,19 +68,19 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, if not seq.is_finished(): prompt_num_computed_tokens = seq.data.get_num_computed_tokens() # Test correctness of num_computed_tokens after the prompt steps - assert prompt_num_computed_tokens == \ - prompt_len + num_prompt_steps - 1 + assert prompt_num_computed_tokens == prompt_len + num_prompt_steps - 1 decode_step_counter = 0 while not seq.is_finished(): # Test correctness of num_computed_tokens after the decode steps - assert seq.data.get_num_computed_tokens( - ) == prompt_num_computed_tokens + decode_step_counter + assert ( + seq.data.get_num_computed_tokens() + == prompt_num_computed_tokens + decode_step_counter + ) for _ in range(num_scheduler_steps): # decode step engine.step() decode_step_counter += 1 # Test correctness of num_computed_tokens after the sequence finish. - assert seq.data.get_num_computed_tokens( - ) == prompt_len + num_output_tokens - 1 + assert seq.data.get_num_computed_tokens() == prompt_len + num_output_tokens - 1 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 591e1780c11c..731d1be00dfd 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -16,9 +16,14 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup, SequenceStatus -from .utils import (append_new_token, append_new_token_seq, - append_new_token_seq_group, create_dummy_prompt, - get_sequence_groups, schedule_and_update_computed_tokens) +from .utils import ( + append_new_token, + append_new_token_seq, + append_new_token_seq_group, + create_dummy_prompt, + get_sequence_groups, + schedule_and_update_computed_tokens, +) def test_scheduler_add_seq_group(): @@ -37,9 +42,7 @@ def test_scheduler_add_seq_group(): # Add seq group to scheduler. num_seq_group = 4 for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - block_size, - block_size=block_size) + _, seq_group = create_dummy_prompt(str(i), block_size, block_size=block_size) scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 @@ -89,9 +92,9 @@ def test_scheduler_schedule_simple(): # Add seq groups to scheduler. 
for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=block_size, block_size=block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -100,8 +103,11 @@ def test_scheduler_schedule_simple(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == num_seq_group append_new_token(out, 1) @@ -109,8 +115,11 @@ def test_scheduler_schedule_simple(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == num_seq_group append_new_token(out, 1) @@ -164,12 +173,8 @@ def test_scheduler_schedule_preempt_abort(): scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - seq_a, seq_group_a = create_dummy_prompt("1", - block_size, - block_size=block_size) - seq_b, seq_group_b = create_dummy_prompt("2", - block_size, - block_size=block_size) + seq_a, seq_group_a = create_dummy_prompt("1", block_size, block_size=block_size) + seq_b, seq_group_b = create_dummy_prompt("2", block_size, block_size=block_size) scheduler.add_seq_group(seq_group_a) scheduler.add_seq_group(seq_group_b) @@ -177,8 +182,11 @@ def test_scheduler_schedule_preempt_abort(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_a, seq_group_b] assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == 2 assert scheduler.get_num_unfinished_seq_groups() == 2 @@ -190,8 +198,11 @@ def test_scheduler_schedule_preempt_abort(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_a] assert out.num_batched_tokens == 1 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == 1 assert scheduler.get_num_unfinished_seq_groups() == 2 assert out.preempted == 1 @@ -201,8 +212,11 @@ def test_scheduler_schedule_preempt_abort(): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert get_sequence_groups(out) == [seq_group_b] assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. 
- assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) assert len(seq_group_meta) == 1 assert scheduler.get_num_unfinished_seq_groups() == 1 @@ -226,9 +240,9 @@ def test_scheduler_max_seqs(): all_seq_groups: list[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=block_size, block_size=block_size + ) all_seq_groups.append(seq_group) # Append 1 seq group @@ -270,33 +284,33 @@ def test_scheduler_delay_factor(): scheduler = Scheduler(scheduler_config, cache_config, None) # schedule first prompt - seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size, - block_size=block_size) + seq_group_meta, seq_group = create_dummy_prompt( + "0", prompt_length=block_size, block_size=block_size + ) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '0' + assert seq_group_meta[0].request_id == "0" append_new_token(out, 1) # wait for a second before scheduling next prompt time.sleep(1) - seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size, - block_size=block_size) + seq_group_meta, seq_group = create_dummy_prompt( + "1", prompt_length=block_size, block_size=block_size + ) scheduler.add_seq_group(seq_group) # second prompt should *not* be scheduled seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups == 0 - assert seq_group_meta[0].request_id == '0' + assert seq_group_meta[0].request_id == "0" append_new_token(out, 1) # wait for more than 0.5 second and try again time.sleep(0.6) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '1' + assert seq_group_meta[0].request_id == "1" append_new_token(out, 1) @@ -333,20 +347,20 @@ def initialize_scheduler( return scheduler -def create_token_budget(token_budget: int = 10000, - max_num_seqs: int = 10000) -> SchedulingBudget: +def create_token_budget( + token_budget: int = 10000, max_num_seqs: int = 10000 +) -> SchedulingBudget: return SchedulingBudget( token_budget=token_budget, max_num_seqs=max_num_seqs, ) -def add_token_budget(budget: SchedulingBudget, - num_batched_tokens: int = 0, - num_curr_seqs: int = 0): - mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1] - budget.add_num_batched_tokens(mock_seq_group.request_id, - num_batched_tokens) +def add_token_budget( + budget: SchedulingBudget, num_batched_tokens: int = 0, num_curr_seqs: int = 0 +): + mock_seq_group = create_dummy_prompt("10", prompt_length=60)[1] + budget.add_num_batched_tokens(mock_seq_group.request_id, num_batched_tokens) budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) @@ -356,9 +370,7 @@ def test_prefill_schedule_max_prompt_len(): """ block_size = 4 scheduler = initialize_scheduler(max_model_len=30, block_size=block_size) - _, seq_group = create_dummy_prompt("0", - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt("0", prompt_length=60, block_size=block_size) scheduler.add_seq_group(seq_group) budget = create_token_budget() output = scheduler._schedule_prefills(budget, None) @@ -375,14 +387,14 @@ def 
test_prefill_schedule_token_budget(): Test token budget respected. """ block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64 + ) budget = create_token_budget(token_budget=0) for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. @@ -405,14 +417,12 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=16, num_gpu_blocks=16 + ) budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, block_size=block_size) # Cannot schedule a prompt that doesn't fit the budget. scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) @@ -437,14 +447,14 @@ def test_prefill_schedule_max_seqs(): Test max seq respected. """ block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64 + ) budget = create_token_budget(max_num_seqs=2) for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -458,9 +468,7 @@ def test_prefill_schedule_max_seqs(): scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -477,20 +485,23 @@ def test_prefill_schedule_max_lora(): """ block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) + scheduler = initialize_scheduler( + lora_config=lora_config, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64, + ) budget = create_token_budget(token_budget=120) curr_loras: set[int] = set() for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=60, + block_size=block_size, + lora_request=LoRARequest( + lora_name=str(i), lora_int_id=i + 1, lora_path="abc" + ), + ) scheduler.add_seq_group(seq_group) # Add two more requests to verify lora is prioritized. 
# 0: LoRA, 1: LoRA, 2: regular, 3: regular @@ -498,9 +509,9 @@ def test_prefill_schedule_max_lora(): # If a request is not scheduled because it hits max lora, it is # prioritized. Verify that. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) output = scheduler._schedule_prefills(budget, curr_loras) @@ -529,14 +540,14 @@ def test_prefill_schedule_no_block_manager_capacity(): Test sequence cannot be scheduled due to block manager has no capacity. """ block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_gpu_blocks=128, - num_cpu_blocks=128) + scheduler = initialize_scheduler( + block_size=block_size, num_gpu_blocks=128, num_cpu_blocks=128 + ) budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER @@ -551,9 +562,9 @@ def test_prefill_schedule_no_block_manager_capacity(): scheduler = initialize_scheduler() budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER @@ -571,14 +582,14 @@ def test_decode_schedule_preempted(): Test decodes cannot be scheduled and preempted. """ block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64 + ) curr_loras = None for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._add_seq_group_to_running(seq_group) @@ -587,8 +598,7 @@ def test_decode_schedule_preempted(): def cannot_append_second_group(seq_group, num_lookahead_slots): return seq_group.request_id != "1" - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) + scheduler.block_manager.can_append_slots.side_effect = cannot_append_second_group # 1 cannot be scheduled, and the lowest priority (request 2) # should be preempted. 1 will also be preempted. @@ -615,12 +625,8 @@ def test_schedule_decode_blocks_to_copy_update(): Verify blocks_to_copy is updated. 
""" block_size = 4 - scheduler = initialize_scheduler(block_size=4, - num_cpu_blocks=16, - num_gpu_blocks=16) - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) + scheduler = initialize_scheduler(block_size=4, num_cpu_blocks=16, num_gpu_blocks=16) + _, seq_group = create_dummy_prompt("1", prompt_length=60, block_size=block_size) curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -648,20 +654,23 @@ def test_schedule_decode_blocks_to_copy_update(): def test_schedule_swapped_max_loras(): block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) + scheduler = initialize_scheduler( + lora_config=lora_config, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32, + ) curr_loras: set[int] = set() blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=60, + block_size=block_size, + lora_request=LoRARequest( + lora_name=str(i), lora_int_id=i + 1, lora_path="abc" + ), + ) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -680,15 +689,15 @@ def test_schedule_swapped_max_loras(): def test_schedule_swapped_cannot_swap_in(): block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32 + ) curr_loras = None blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -710,15 +719,15 @@ def test_schedule_swapped_cannot_swap_in(): def test_infeasible_swap(): block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32 + ) curr_loras = None blocks_to_swap_out: list[tuple[int, int]] = [] for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt( + str(i), prompt_length=60, block_size=block_size + ) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -741,13 +750,11 @@ def test_infeasible_swap(): def test_schedule_swapped_blocks_to_copy(): block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) + scheduler = initialize_scheduler( + block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32 + ) curr_loras = None - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) + _, seq_group = create_dummy_prompt("1", prompt_length=60, block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, 
seq_group, 1) blocks_to_swap_out: list[tuple[int, int]] = [] @@ -840,26 +847,30 @@ def test_prefix_caching_aware_prefills(enable_prefix_caching): seqA_tokens = list(range(8)) num_shared_tokens = 4 - seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 12, 16)) # Shared prefix first 4. - seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 16, 20)) # Shared prefix first 4. - - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) + seqB_tokens = seqA_tokens[:num_shared_tokens] + list( + range(12, 16) + ) # Shared prefix first 4. + seqC_tokens = seqA_tokens[:num_shared_tokens] + list( + range(16, 20) + ) # Shared prefix first 4. + + seqA, seqA_group = create_dummy_prompt( + "0", prompt_tokens=seqA_tokens, block_size=block_size + ) + seqB, seqB_group = create_dummy_prompt( + "1", prompt_tokens=seqB_tokens, block_size=block_size + ) + seqC, seqC_group = create_dummy_prompt( + "2", prompt_tokens=seqC_tokens, block_size=block_size + ) # Schedule seqA prefill. scheduler.add_seq_group(seqA_group) metas, out, _ = scheduler.schedule() - assert (len(out.scheduled_seq_groups) == 1 - and out.scheduled_seq_groups[0].seq_group == seqA_group) + assert ( + len(out.scheduled_seq_groups) == 1 + and out.scheduled_seq_groups[0].seq_group == seqA_group + ) assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) # Schedule seqA decode. @@ -877,15 +888,18 @@ def test_prefix_caching_aware_prefills(enable_prefix_caching): if enable_prefix_caching: assert len(out.scheduled_seq_groups) == 2 - assert set([ - out.scheduled_seq_groups[0].seq_group, - out.scheduled_seq_groups[1].seq_group, - ]) == set([seqB_group, seqC_group]) + assert set( + [ + out.scheduled_seq_groups[0].seq_group, + out.scheduled_seq_groups[1].seq_group, + ] + ) == set([seqB_group, seqC_group]) assert len(metas) == 2 for meta in metas: assert meta.token_chunk_size == 8 - assert (len(meta.computed_block_nums) == num_shared_tokens // - block_size) # 1 Block for the 8 tokens. + assert ( + len(meta.computed_block_nums) == num_shared_tokens // block_size + ) # 1 Block for the 8 tokens. else: assert len(out.scheduled_seq_groups) == 1 assert len(metas) == 1 @@ -893,8 +907,7 @@ def test_prefix_caching_aware_prefills(enable_prefix_caching): assert len(metas[0].computed_block_nums) == 0 # No blocks computed. -def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( -): +def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(): """ This test verifies that we don't schedule new prefills if there's already a continuous prefill in progress even though the new prefills with shared @@ -931,12 +944,12 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( seqC_shared_prefix_len = 4 seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) + seqA, seqA_group = create_dummy_prompt( + "0", prompt_tokens=seqA_tokens, block_size=block_size + ) + seqB, seqB_group = create_dummy_prompt( + "1", prompt_tokens=seqB_tokens, block_size=block_size + ) # Chunked prefill seqA. 
scheduler.add_seq_group(seqA_group) @@ -955,27 +968,26 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( # both seqB and seqC can now be scheduled with seqA is over. # seqA is in decoding phase. append_new_token_seq(seqA, 999) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) + seqC, seqC_group = create_dummy_prompt( + "2", prompt_tokens=seqC_tokens, block_size=block_size + ) scheduler.add_seq_group(seqC_group) metas, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 3 metas = {meta.request_id: meta for meta in metas} assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode - assert (metas[seqB_group.request_id].token_chunk_size == 8 - ) # Fully cached prefill - assert ( - metas[seqC_group.request_id].token_chunk_size == 6 - ), "A partial prefix of C (4 tokens) should be prefilled, with the " + assert metas[seqB_group.request_id].token_chunk_size == 8 # Fully cached prefill + assert metas[seqC_group.request_id].token_chunk_size == 6, ( + "A partial prefix of C (4 tokens) should be prefilled, with the " + ) "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "then be rounded down to 2 tokens on block size, thus 6 tokens in total." def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): """ - Test that the scheduler does not schedule batches with prompt tokens and + Test that the scheduler does not schedule batches with prompt tokens and prompt embeddings co-mingled. """ block_size = 2 @@ -1005,10 +1017,12 @@ def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): seq_embeds.append(torch.rand(embedding_size)) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", + prompt_tokens=seq_tokens[i], + prompt_embeds=seq_embeds[i], + block_size=block_size, + ) for i in range(len(seq_tokens)) ] @@ -1017,24 +1031,29 @@ def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): unfinished_seq_groups = [ - seq_group for _, seq_group in seq_and_seq_groups + seq_group + for _, seq_group in seq_and_seq_groups if not seq_group.is_finished() ] _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) > 0 batch_is_prompt_embeds = out.scheduled_seq_groups[ - 0].seq_group.uses_prompt_embeds() + 0 + ].seq_group.uses_prompt_embeds() expected_scheduled_seq_groups = [ - seq_group for seq_group in unfinished_seq_groups + seq_group + for seq_group in unfinished_seq_groups if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds ] # We should have as many scheduled groups as possible, without mixing assert len(out.scheduled_seq_groups) == min( - max_seq_group, len(expected_scheduled_seq_groups)) - assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == - batch_is_prompt_embeds - for scheduled_seq_group in out.scheduled_seq_groups) + max_seq_group, len(expected_scheduled_seq_groups) + ) + assert all( + scheduled_seq_group.seq_group.uses_prompt_embeds() == batch_is_prompt_embeds + for scheduled_seq_group in out.scheduled_seq_groups + ) # Finish the scheduled groups for scheduled_seq_group in out.scheduled_seq_groups: @@ -1078,9 +1097,9 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_with_swapped.append([i] * seq_length) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - 
prompt_tokens=seq_tokens_with_swapped[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", prompt_tokens=seq_tokens_with_swapped[i], block_size=block_size + ) for i in range(len(seq_tokens_with_swapped)) ] @@ -1090,43 +1109,46 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler._add_seq_group_to_swapped(seq_group) scheduler._schedule_swapped(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None # Prefill schedule don't have a space for another LoRA, so # we ignore this request for now. block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64, - enable_prefix_caching=True) + scheduler = initialize_scheduler( + lora_config=lora_config, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64, + enable_prefix_caching=True, + ) budget = create_token_budget(token_budget=120) num_seqs = 2 for i in range(num_seqs): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=seq_length, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=seq_length, + block_size=block_size, + lora_request=LoRARequest( + lora_name=str(i), lora_int_id=i + 1, lora_path="abc" + ), + ) scheduler.add_seq_group(seq_group) scheduler._schedule_prefills(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None # Priority preemption schedule scheduler._schedule_priority_preemption(budget) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None # Prefill scheduler does not schedule batches with prompt tokens and @@ -1152,10 +1174,12 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_embeds.append(torch.rand(embedding_size)) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_with_embedding[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", + prompt_tokens=seq_tokens_with_embedding[i], + prompt_embeds=seq_embeds[i], + block_size=block_size, + ) for i in range(len(seq_tokens_with_embedding)) ] @@ -1163,9 +1187,9 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None # Prefill scheduler budget num_batched_tokens @@ -1189,9 +1213,9 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_prefill_budget.append([i] * seq_length) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", prompt_tokens=seq_tokens_prefill_budget[i], block_size=block_size + ) for i in range(len(seq_tokens_prefill_budget)) ] @@ -1199,9 +1223,9 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(2)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 2 + ) assert seq_id_to_num_tokens_computed is None # Budget can not schedule in waiting @@ -1225,9 +1249,11 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_prefill_budget_waiting.append(list(range(seq_length))) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget_waiting[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", + prompt_tokens=seq_tokens_prefill_budget_waiting[i], + block_size=block_size, + ) for i in range(len(seq_tokens_prefill_budget_waiting)) ] @@ -1235,9 +1261,9 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None # Sequence num_new_tokens > prompt_limit marked FINISHED_IGNORED @@ -1256,16 +1282,16 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_prompt_limit: list[list[int]] = [] seq_tokens_prompt_limit.append(list(range(seq_length))) seq_and_seq_groups = [ - create_dummy_prompt("0", - prompt_tokens=seq_tokens_prompt_limit[0], - block_size=block_size) + create_dummy_prompt( + "0", prompt_tokens=seq_tokens_prompt_limit[0], block_size=block_size + ) ] for _, seq_group in seq_and_seq_groups: scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(0)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 0 + ) assert seq_id_to_num_tokens_computed is None # Budget can not allocate, AllocStatus is NEVER marked FINISHED_IGNORED @@ -1287,9 +1313,9 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_never.append(list(range(seq_length))) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_never[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", prompt_tokens=seq_tokens_never[i], block_size=block_size + ) for i in range(len(seq_tokens_never)) ] @@ -1297,9 +1323,9 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(0)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 0 + ) assert seq_id_to_num_tokens_computed is None # Budget can not allocate, AllocStatus is LATER @@ -1321,9 +1347,9 @@ def test_remove_seq_from_computed_blocks_tracker(): seq_tokens_later.append(list(range(seq_length))) seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_later[i], - block_size=block_size) + create_dummy_prompt( + f"{i}", prompt_tokens=seq_tokens_later[i], block_size=block_size + ) for i in range(len(seq_tokens_later)) ] @@ -1331,7 +1357,7 @@ def test_remove_seq_from_computed_blocks_tracker(): scheduler.add_seq_group(seq_group) scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) + seq_id_to_num_tokens_computed = scheduler.block_manager._computed_blocks_tracker._seq_id_to_num_tokens_computed.get( + 1 + ) assert seq_id_to_num_tokens_computed is None diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 20cc083ec8db..af93fdf4e074 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -7,12 +7,16 @@ from vllm.core.scheduler import Scheduler from vllm.sequence import SequenceGroup -from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, - get_sequence_groups, schedule_and_update_computed_tokens) +from .utils import ( + append_new_token, + create_dummy_prompt_encoder_decoder, + get_sequence_groups, + schedule_and_update_computed_tokens, +) def test_scheduler_schedule_simple_encoder_decoder(): - ''' + """ Test basic scheduler functionality in the context of an encoder/decoder model. 
Focus on testing enc/dec-specific functionality sense tests already @@ -32,7 +36,7 @@ def test_scheduler_schedule_simple_encoder_decoder(): * Abort scheduled seq groups * Assert that aborted seq groups no longer appear in cross-attention block table - ''' + """ block_size = 4 num_seq_group = 4 @@ -55,7 +59,8 @@ def test_scheduler_schedule_simple_encoder_decoder(): req_id = str(i) req_id_list.append(req_id) _, _, seq_group = create_dummy_prompt_encoder_decoder( - req_id, block_size, block_size, block_size) + req_id, block_size, block_size, block_size + ) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -64,15 +69,22 @@ def test_scheduler_schedule_simple_encoder_decoder(): seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) # - Verify that sequence group cross-attention block tables are # registered with the block manager - assert all([(req_id in scheduler.block_manager.cross_block_tables) - for req_id in req_id_list]) + assert all( + [ + (req_id in scheduler.block_manager.cross_block_tables) + for req_id in req_id_list + ] + ) # - Validate sequence-group status assert set(get_sequence_groups(out)) == set(running) # - Validate number of batched tokens assert out.num_batched_tokens == num_tokens # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) # - Validate all seq groups were scheduled assert len(seq_group_meta_list) == num_seq_group append_new_token(out, 1) @@ -81,18 +93,25 @@ def test_scheduler_schedule_simple_encoder_decoder(): seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) # - Verify that sequence group metadata includes encoder attention # and cross-attention metadata - assert all([ - not ((seq_group_meta.encoder_seq_data is None) or - (seq_group_meta.cross_block_table is None)) - for seq_group_meta in seq_group_meta_list - ]) + assert all( + [ + not ( + (seq_group_meta.encoder_seq_data is None) + or (seq_group_meta.cross_block_table is None) + ) + for seq_group_meta in seq_group_meta_list + ] + ) # - Validate sequence-group status assert set(get_sequence_groups(out)) == set(running) # - Validate there is one batched token per seq group assert out.num_batched_tokens == num_seq_group # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) + assert ( + not out.blocks_to_copy + and not out.blocks_to_swap_in + and not out.blocks_to_swap_out + ) # - Validate that all seq groups were scheduled assert len(seq_group_meta_list) == num_seq_group append_new_token(out, 1) diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 8281298d6634..75e4a37dda69 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -15,22 +15,20 @@ def test_msgspec_serialization(): execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=num_lookahead_slots, - running_queue_size=4) + running_queue_size=4, + ) encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) + decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, dec_hook=decode_hook) req = decoder.decode(encoder.encode(execute_model_req)) expected = execute_model_req.seq_group_metadata_list actual = 
req.seq_group_metadata_list - assert (len(expected) == len(actual)) + assert len(expected) == len(actual) expected = expected[0] actual = actual[0] assert expected.block_tables == actual.block_tables assert expected.is_prompt == actual.is_prompt assert expected.request_id == actual.request_id - assert (expected.seq_data[0].prompt_token_ids == - actual.seq_data[0].prompt_token_ids) - assert (expected.seq_data[0].output_token_ids == - actual.seq_data[0].output_token_ids) + assert expected.seq_data[0].prompt_token_ids == actual.seq_data[0].prompt_token_ids + assert expected.seq_data[0].output_token_ids == actual.seq_data[0].output_token_ids diff --git a/tests/core/utils.py b/tests/core/utils.py index b746c1786464..124c080b017e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -12,8 +12,7 @@ from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs from vllm.lora.request import LoRARequest -from vllm.sequence import (Logprob, Sequence, SequenceGroup, - SequenceGroupMetadata) +from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceGroupMetadata def create_dummy_prompt( @@ -35,10 +34,11 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - inputs = token_inputs( - prompt_token_ids=prompt_tokens, - prompt=prompt_str) if prompt_embeds is None else embeds_inputs( - prompt_embeds=prompt_embeds) + inputs = ( + token_inputs(prompt_token_ids=prompt_tokens, prompt=prompt_str) + if prompt_embeds is None + else embeds_inputs(prompt_embeds=prompt_embeds) + ) prompt = Sequence( int(request_id), inputs=inputs, @@ -48,26 +48,29 @@ def create_dummy_prompt( request_id=request_id, seqs=[prompt], arrival_time=time.time(), - sampling_params=SamplingParams(max_tokens=max_tokens, - min_tokens=min_tokens), + sampling_params=SamplingParams(max_tokens=max_tokens, min_tokens=min_tokens), lora_request=lora_request, ) return prompt, seq_group -def create_dummy_lora_sequence(request_id: int, token_ids: list[int], - block_size: int, lora_int_id: int) -> Sequence: - return Sequence(seq_id=request_id, - inputs=token_inputs(token_ids), - block_size=block_size, - lora_request=LoRARequest(lora_name="dummy", - lora_path="/dummy", - lora_int_id=lora_int_id)) +def create_dummy_lora_sequence( + request_id: int, token_ids: list[int], block_size: int, lora_int_id: int +) -> Sequence: + return Sequence( + seq_id=request_id, + inputs=token_inputs(token_ids), + block_size=block_size, + lora_request=LoRARequest( + lora_name="dummy", lora_path="/dummy", lora_int_id=lora_int_id + ), + ) -def create_dummy_sequence(request_id: int, token_ids: list[int], - block_size: int) -> Sequence: +def create_dummy_sequence( + request_id: int, token_ids: list[int], block_size: int +) -> Sequence: return Sequence( seq_id=request_id, inputs=token_inputs(token_ids), @@ -94,36 +97,36 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) inputs: EncoderDecoderInputs = { - "decoder": token_inputs(decoder_prompt_tokens, - prompt=decoder_prompt_str), - "encoder": token_inputs(encoder_prompt_tokens, - prompt=encoder_prompt_str), + "decoder": token_inputs(decoder_prompt_tokens, prompt=decoder_prompt_str), + "encoder": token_inputs(encoder_prompt_tokens, prompt=encoder_prompt_str), } - decoder_prompt = Sequence(int(request_id), - inputs=inputs["decoder"], - block_size=block_size) + decoder_prompt = Sequence( + int(request_id), 
inputs=inputs["decoder"], block_size=block_size + ) - encoder_prompt = Sequence(int(request_id), - inputs=inputs["encoder"], - block_size=block_size) + encoder_prompt = Sequence( + int(request_id), inputs=inputs["encoder"], block_size=block_size + ) - seq_group = SequenceGroup(request_id=request_id, - seqs=[decoder_prompt], - arrival_time=time.time(), - lora_request=lora_request, - encoder_seq=encoder_prompt) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[decoder_prompt], + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt, + ) return decoder_prompt, encoder_prompt, seq_group def create_seq_group( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - + seq_prompt_len: int = 1024, + seq_output_lens: GenericSequence[int] = (128,), + request_id: str = "0", + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None, +) -> SequenceGroup: assert len(seq_output_lens) > 0 if sampling_params is None: @@ -157,12 +160,12 @@ def create_seq_group( def create_seq_group_encoder_decoder( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - + seq_prompt_len: int = 1024, + seq_output_lens: GenericSequence[int] = (128,), + request_id: str = "0", + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None, +) -> SequenceGroup: assert len(seq_output_lens) > 0 if sampling_params is None: @@ -198,11 +201,13 @@ def create_seq_group_encoder_decoder( block_size=16, ) - return SequenceGroup(request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq) + return SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq, + ) def round_up_to_next_block(seq_len: int, block_size: int) -> int: @@ -250,7 +255,6 @@ def __init__(self, scheduler: Scheduler): self.call_history: dict[str, list[Any]] = defaultdict(list) def __getattr__(self, name: str) -> Any: - def wrapper(*args, **kwargs): result = getattr(self.scheduler_, name)(*args, **kwargs) self.call_history[name].append((args, kwargs, result)) @@ -259,6 +263,7 @@ def wrapper(*args, **kwargs): return wrapper def last_schedule_ret( - self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: + self, + ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: _, _, ret = self.call_history["schedule"][-1] return ret diff --git a/tests/cuda/test_cuda_context.py b/tests/cuda/test_cuda_context.py index f973b284b87e..6336f2112c66 100644 --- a/tests/cuda/test_cuda_context.py +++ b/tests/cuda/test_cuda_context.py @@ -13,7 +13,7 @@ def check_cuda_context(): """Check CUDA driver context status""" try: - cuda = ctypes.CDLL('libcuda.so') + cuda = ctypes.CDLL("libcuda.so") device = ctypes.c_int() result = cuda.cuCtxGetDevice(ctypes.byref(device)) return (True, device.value) if result == 0 else (False, None) @@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id): # New thread should have no CUDA context initially valid_before, device_before = check_cuda_context() if valid_before: - return False, \ - "CUDA context should not exist in new thread, " \ - f"got device {device_before}" + return ( + False, + "CUDA context should not 
exist in new thread, " + f"got device {device_before}", + ) # Test setting CUDA context current_platform.set_device(device_input) @@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id): if not valid_after: return False, "CUDA context should be valid after set_cuda_context" if device_id != expected_device_id: - return False, \ - f"Expected device {expected_device_id}, got {device_id}" + return False, f"Expected device {expected_device_id}, got {device_id}" return True, "Success" except Exception as e: @@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id): class TestSetCudaContext: """Test suite for the set_cuda_context function.""" - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") - @pytest.mark.parametrize(argnames="device_input,expected_device_id", - argvalues=[ - (0, 0), - (torch.device('cuda:0'), 0), - ('cuda:0', 0), - ], - ids=["int", "torch_device", "string"]) - def test_set_cuda_context_parametrized(self, device_input, - expected_device_id): + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") + @pytest.mark.parametrize( + argnames="device_input,expected_device_id", + argvalues=[ + (0, 0), + (torch.device("cuda:0"), 0), + ("cuda:0", 0), + ], + ids=["int", "torch_device", "string"], + ) + def test_set_cuda_context_parametrized(self, device_input, expected_device_id): """Test setting CUDA context in isolated threads.""" with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_cuda_test_in_thread, device_input, - expected_device_id) + future = executor.submit( + run_cuda_test_in_thread, device_input, expected_device_id + ) success, message = future.result(timeout=30) assert success, message - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") def test_set_cuda_context_invalid_device_type(self): """Test error handling for invalid device type.""" with pytest.raises(ValueError, match="Expected a cuda device"): - current_platform.set_device(torch.device('cpu')) + current_platform.set_device(torch.device("cpu")) if __name__ == "__main__": diff --git a/tests/detokenizer/test_disable_detokenization.py b/tests/detokenizer/test_disable_detokenization.py index ae06a985c7ec..a77626df5dc7 100644 --- a/tests/detokenizer/test_disable_detokenization.py +++ b/tests/detokenizer/test_disable_detokenization.py @@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str): prompt = ( "You are a helpful assistant. How do I build a car from cardboard and " "paper clips? Is there an easy to follow video tutorial available " - "online for free?") + "online for free?" 
+ ) llm = LLM(model=model) - sampling_params = SamplingParams(max_tokens=10, - temperature=0.0, - detokenize=False) + sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False) - outputs_no_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] sampling_params.detokenize = True - outputs_with_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] - assert outputs_no_detokenization.text == '' - assert outputs_with_detokenization.text != '' - assert outputs_no_detokenization.token_ids == \ - outputs_with_detokenization.token_ids + assert outputs_no_detokenization.text == "" + assert outputs_with_detokenization.text != "" + assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids diff --git a/tests/detokenizer/test_stop_checker.py b/tests/detokenizer/test_stop_checker.py index bd221977224f..dbb62b305e3c 100644 --- a/tests/detokenizer/test_stop_checker.py +++ b/tests/detokenizer/test_stop_checker.py @@ -12,8 +12,7 @@ from vllm.sequence import Logprob, Sequence, SequenceStatus -def sequence_with_eos(text: str, eos_token: str, - eos_token_id: int) -> Sequence: +def sequence_with_eos(text: str, eos_token: str, eos_token_id: int) -> Sequence: """ Create a Sequence that ends with an EOS token. """ @@ -28,22 +27,29 @@ def sequence_with_eos(text: str, eos_token: str, offset = eos_token_id + 1 for i in range(offset, len(text) + offset): seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) - seq.append_token_id(token_id=eos_token_id, - logprobs={eos_token_id: Logprob(0.0)}) + seq.append_token_id(token_id=eos_token_id, logprobs={eos_token_id: Logprob(0.0)}) seq.status = SequenceStatus.RUNNING return seq -@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ - ("This text ends with EOS token", "", 2), -]) +@pytest.mark.parametrize( + ["text_wo_eos", "eos_token", "eos_token_id"], + [ + ("This text ends with EOS token", "", 2), + ], +) @pytest.mark.parametrize("ignore_eos", [True, False]) @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) @pytest.mark.skip_global_cleanup -def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, - ignore_eos: bool, include_stop_str_in_output: bool): +def test_stop_on_eos_token( + text_wo_eos: str, + eos_token: str, + eos_token_id: int, + ignore_eos: bool, + include_stop_str_in_output: bool, +): """ Test the behavior of the StopChecker's maybe_stop_sequence method when an EOS token is encountered. 
@@ -56,8 +62,9 @@ def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, tokenizer = MagicMock(spec=PreTrainedTokenizer) get_tokenizer_for_seq = MagicMock(return_value=tokenizer) - stop_checker = StopChecker(max_model_len=1024, - get_tokenizer_for_seq=get_tokenizer_for_seq) + stop_checker = StopChecker( + max_model_len=1024, get_tokenizer_for_seq=get_tokenizer_for_seq + ) seq = sequence_with_eos( text=text_wo_eos, @@ -70,7 +77,8 @@ def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, sampling_params = SamplingParams( min_tokens=1, ignore_eos=ignore_eos, - include_stop_str_in_output=include_stop_str_in_output) + include_stop_str_in_output=include_stop_str_in_output, + ) stop_checker.maybe_stop_sequence( seq=seq, diff --git a/tests/detokenizer/test_stop_reason.py b/tests/detokenizer/test_stop_reason.py index 9716f7d72a58..7415f96c2665 100644 --- a/tests/detokenizer/test_stop_reason.py +++ b/tests/detokenizer/test_stop_reason.py @@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts): llm = vllm_model.model # test stop token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop_token_ids=[stop_token_id])) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop_token_ids=[stop_token_id], + ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == stop_token_id # test stop string - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop=".")) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="." 
+ ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == STOP_STR # test EOS token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS)) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "length" or ( - output.finish_reason == "stop" and output.stop_reason is None) + output.finish_reason == "stop" and output.stop_reason is None + ) diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index efe938a20c4f..d544bbbbc5db 100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -11,12 +11,14 @@ MAX_TOKENS = 200 -def _test_stopping(llm: LLM, - expected_output: str, - expected_reason: Any, - stop: Optional[list[str]] = None, - stop_token_ids: Optional[list[int]] = None, - include_in_output: bool = False) -> None: +def _test_stopping( + llm: LLM, + expected_output: str, + expected_reason: Any, + stop: Optional[list[str]] = None, + stop_token_ids: Optional[list[int]] = None, + include_in_output: bool = False, +) -> None: output = llm.generate( "A story about vLLM:\n", SamplingParams( @@ -25,7 +27,8 @@ def _test_stopping(llm: LLM, stop=stop, stop_token_ids=stop_token_ids, include_stop_str_in_output=include_in_output, - ))[0].outputs[0] + ), + )[0].outputs[0] assert output is not None assert output.text == expected_output @@ -37,17 +40,21 @@ def _set_async_mode(llm, is_async): def _stop_basic(llm): - _test_stopping(llm, - stop=["."], - include_in_output=False, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=".") + _test_stopping( + llm, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".", + ) - _test_stopping(llm, - stop=["."], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization.", - expected_reason=".") + _test_stopping( + llm, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".", + ) def _stop_multi_tokens(llm): @@ -56,45 +63,54 @@ def _stop_multi_tokens(llm): stop=["group of peo", "short"], include_in_output=False, expected_output="VLLM is a 100% volunteer organization. We are a ", - expected_reason="group of peo") + expected_reason="group of peo", + ) _test_stopping( llm, stop=["group of peo", "short"], include_in_output=True, - expected_output= - "VLLM is a 100% volunteer organization. We are a group of peo", - expected_reason="group of peo") + expected_output="VLLM is a 100% volunteer organization. 
We are a group of peo", + expected_reason="group of peo", + ) def _stop_partial_token(llm): - _test_stopping(llm, - stop=["gani"], - include_in_output=False, - expected_output="VLLM is a 100% volunteer or", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani", + ) - _test_stopping(llm, - stop=["gani"], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organi", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani", + ) def _stop_token_id(llm): # token id 13013 => " organization" - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=False, - expected_output="VLLM is a 100% volunteer", - expected_reason=13013) - - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=13013) + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013, + ) + + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013, + ) @pytest.mark.skip_global_cleanup diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 666a715cc0da..efe6fee58f31 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -111,8 +111,7 @@ def __init__( self.last_seq = -1 self.decoder = msgspec.msgpack.Decoder(type=decode_type) - def receive_one(self, - timeout=1000) -> Union[tuple[int, SampleBatch], None]: + def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]: """Receive a single message with timeout""" if not self.sub.poll(timeout): return None @@ -135,8 +134,7 @@ def request_replay(self, start_seq: int, socket_idx: int = 0) -> None: self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big")) - def receive_replay(self, - socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: + def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: """Receive replayed messages from a specific replay socket""" if not self.replay_sockets: raise ValueError("Replay sockets not initialized") diff --git a/tests/distributed/test_ca_buffer_sharing.py b/tests/distributed/test_ca_buffer_sharing.py index e2de462612b4..1ddce64f8e61 100644 --- a/tests/distributed/test_ca_buffer_sharing.py +++ b/tests/distributed/test_ca_buffer_sharing.py @@ -12,7 +12,8 @@ from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa - CustomAllreduce) + CustomAllreduce, +) # create a cpu process group for communicating metadata (ipc handle) dist.init_process_group(backend="gloo") @@ -52,7 +53,8 @@ assert ord(host_data[i]) == byte_value, ( f"Rank {rank} failed" f" to verify buffer {p}. 
Expected {byte_value}, " - f"got {ord(host_data[i])}") + f"got {ord(host_data[i])}" + ) print(f"Rank {rank} verified all buffers") diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e2cb579e22dc..5def4f9c1316 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -13,10 +13,13 @@ import ray import torch -from vllm.distributed import (broadcast_tensor_dict, get_pp_group, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) +from vllm.distributed import ( + broadcast_tensor_dict, + get_pp_group, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter, +) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -36,12 +39,11 @@ def all_reduce_test_worker( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank % tp_size] @@ -50,28 +52,31 @@ def all_reduce_test_worker( @ray.remote(num_gpus=1, max_calls=1) -def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, - pp_size: int, rank: int, - distributed_init_port: str): +def reduce_scatter_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, +): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] index = rank % tp_size partition_size = num_elements // tp_size all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - expected = all_reduce[index * partition_size:(index + 1) * partition_size] + expected = all_reduce[index * partition_size : (index + 1) * partition_size] t = all_tensors[index] t = tensor_model_parallel_reduce_scatter(t, 0) torch.testing.assert_close(t, expected) @@ -91,8 +96,7 @@ def all_gather_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) total_size = 1 @@ -100,8 +104,10 @@ def all_gather_test_worker( total_size *= s for all_gather_dimension in range(num_dimensions): all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="cuda").reshape(tensor_size) * (r 
+ 1) + torch.arange(total_size, dtype=torch.float32, device="cuda").reshape( + tensor_size + ) + * (r + 1) for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) @@ -124,8 +130,7 @@ def broadcast_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), @@ -133,10 +138,7 @@ def broadcast_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -165,8 +167,7 @@ def send_recv_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor @@ -175,10 +176,7 @@ def send_recv_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -210,8 +208,7 @@ def send_recv_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) size = 64 test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") @@ -226,13 +223,14 @@ def send_recv_test_worker( torch.testing.assert_close(test_tensor, recv_tensor) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) @pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", [ - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker], +) def test_multi_process_tensor_parallel( monkeypatch: pytest.MonkeyPatch, tp_size: int, @@ -241,11 +239,13 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( - "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) + "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker] +) def test_multi_process_pipeline_parallel( monkeypatch: pytest.MonkeyPatch, pp_size: int, @@ -254,15 +254,21 @@ def test_multi_process_pipeline_parallel( multi_process_parallel(monkeypatch, 1, pp_size, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) -@pytest.mark.parametrize("test_target", [ - send_recv_test_worker, send_recv_tensor_dict_test_worker, - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [ + send_recv_test_worker, + send_recv_tensor_dict_test_worker, + all_reduce_test_worker, + all_gather_test_worker, + broadcast_tensor_dict_test_worker, + ], +) def test_multi_process_tensor_parallel_pipeline_parallel( tp_size: int, pp_size: int, diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index fae49c41d5f8..68d37709c4af 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -8,13 +8,18 @@ import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture) - -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_group, + get_tp_group, + graph_capture, +) + +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] @@ -34,8 +39,7 @@ def graph_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tensor_model_parallel_group().device_group @@ -61,18 +65,15 @@ def graph_allreduce( for dtype in [torch.float32, torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: # use integers so result matches NCCL exactly - inp1 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for i in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) # the 
input buffer is immediately modified to test @@ -97,8 +98,7 @@ def eager_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # we use the first group to communicate once # and the second group to communicate twice @@ -133,5 +133,4 @@ def test_custom_allreduce( world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_distributed_oot.py b/tests/distributed/test_distributed_oot.py index b93696e4be0e..ea7a88abda24 100644 --- a/tests/distributed/test_distributed_oot.py +++ b/tests/distributed/test_distributed_oot.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from ..entrypoints.openai.test_oot_registration import ( - run_and_test_dummy_opt_api_server) +from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server def test_distributed_oot(dummy_opt_path: str): diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index e47ccba99c81..79805a7cce53 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -10,10 +10,12 @@ def test_basic_rebalance(): """Test basic rebalancing functionality""" # Example from https://github.com/deepseek-ai/eplb - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ] + ) num_layers = weight.shape[0] num_replicas = 16 @@ -21,45 +23,49 @@ def test_basic_rebalance(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify output shapes assert phy2log.shape == ( 2, 16, ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" - assert (log2phy.shape[0] == 2 - ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" - assert ( - log2phy.shape[1] == 12 - ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert log2phy.shape[0] == 2, ( + f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + ) + assert log2phy.shape[1] == 12, ( + f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + ) assert logcnt.shape == ( 2, 12, ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" # Verify physical to logical expert mapping range is correct - assert torch.all(phy2log >= 0) and torch.all( - phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), ( + "Physical to logical mapping should be in range [0, 12)" + ) # Verify expert count reasonableness - assert torch.all( - logcnt >= 1), "Each logical expert should have at least 1 replica" - assert ( - torch.sum(logcnt, dim=1).sum() == num_replicas * - num_layers), f"Total replicas should be 
{num_replicas * num_layers}" + assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica" + assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, ( + f"Total replicas should be {num_replicas * num_layers}" + ) # Verify expected output - expected_phy2log = torch.tensor([ - [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], - [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], - ]) + expected_phy2log = torch.tensor( + [ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ] + ) assert torch.all(phy2log == expected_phy2log) - expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], - [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + expected_logcnt = torch.tensor( + [[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]] + ) assert torch.all(logcnt == expected_logcnt) @@ -71,9 +77,9 @@ def test_single_gpu_case(): num_nodes = 1 num_gpus = 1 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 4) @@ -93,19 +99,19 @@ def test_equal_weights(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 8) assert logcnt.shape == (1, 8) # With equal weights, each expert should have exactly one replica - assert torch.all( - logcnt == 1 - ), "With equal weights and no replication, " \ - "each expert should have exactly 1 replica" + assert torch.all(logcnt == 1), ( + "With equal weights and no replication, " + "each expert should have exactly 1 replica" + ) def test_extreme_weight_imbalance(): @@ -116,35 +122,37 @@ def test_extreme_weight_imbalance(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 12) assert logcnt.shape == (1, 8) # Expert with highest weight (index 0) should have more replicas - assert ( - logcnt[0, 0] - > logcnt[0, 1]), "Expert with highest weight should have more replicas" + assert logcnt[0, 0] > logcnt[0, 1], ( + "Expert with highest weight should have more replicas" + ) def test_multiple_layers(): """Test multiple layers case""" - weight = torch.tensor([ - [10, 20, 30, 40, 50, 60], # First layer - [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) - [25, 25, 25, 25, 25, 25], # Third layer (equal weights) - ]) + weight = torch.tensor( + [ + [10, 20, 30, 40, 50, 60], # First layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ] + ) num_replicas = 8 num_groups = 2 num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (3, 8) @@ -152,12 +160,12 @@ def test_multiple_layers(): # Verify expert allocation is reasonable for each layer for layer in range(3): - assert torch.all(phy2log[layer] >= 
0) and torch.all( - phy2log[layer] < 6 - ), f"Layer {layer} physical to logical mapping" \ - "should be in range [0, 6)" - assert (torch.sum(logcnt[layer]) == num_replicas - ), f"Layer {layer} total replicas should be {num_replicas}" + assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), ( + f"Layer {layer} physical to logical mappingshould be in range [0, 6)" + ) + assert torch.sum(logcnt[layer]) == num_replicas, ( + f"Layer {layer} total replicas should be {num_replicas}" + ) def test_parameter_validation(): @@ -179,17 +187,19 @@ def test_parameter_validation(): def test_small_scale_hierarchical(): """Test small-scale hierarchical load balancing""" - weight = torch.tensor([ - [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts - ]) + weight = torch.tensor( + [ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ] + ) num_replicas = 12 num_groups = 4 # 4 groups, 2 experts each num_nodes = 2 # 2 nodes num_gpus = 4 # 4 GPUs - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify basic constraints assert phy2log.shape == (1, 12) @@ -199,8 +209,9 @@ def test_small_scale_hierarchical(): # Expert with highest weight should have more replicas max_weight_expert = torch.argmax(weight[0]) - assert (logcnt[0, max_weight_expert] - >= 2), "Highest weight expert should have multiple replicas" + assert logcnt[0, max_weight_expert] >= 2, ( + "Highest weight expert should have multiple replicas" + ) def test_global_load_balance_fallback(): @@ -213,9 +224,9 @@ def test_global_load_balance_fallback(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Should work normally, just using global load balancing strategy assert phy2log.shape == (1, 8) @@ -235,9 +246,9 @@ def test_device_compatibility(device): num_nodes = 1 num_gpus = 2 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Function will convert to CPU internally, but should handle different # device inputs normally @@ -250,7 +261,8 @@ def test_additional_cases(): # Test case 1: Large-scale distributed setup weight1 = torch.tensor( - [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] + ) phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) assert phy2log1.shape == (1, 24) @@ -258,10 +270,12 @@ def test_additional_cases(): assert torch.sum(logcnt1) == 24 # Test case 2: Different weight distributions - weight2 = torch.tensor([ - [200, 150, 100, 50, 25, 12], # Decreasing weights - [12, 25, 50, 100, 150, 200], # Increasing weights - ]) + weight2 = torch.tensor( + [ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ] + ) phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) assert phy2log2.shape == (2, 10) @@ -274,19 +288,21 @@ def test_additional_cases(): if __name__ == "__main__": - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = 
torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ] + ) num_replicas = 16 num_groups = 4 num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) print(phy2log) test_basic_rebalance() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index de9ed1eabbac..7ca3d3d27b56 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -9,11 +9,12 @@ import torch import torch.distributed -from vllm.distributed.eplb.rebalance_execute import ( - rearrange_expert_weights_inplace) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_tp_group, - init_distributed_environment) +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -22,13 +23,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -45,7 +46,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -60,20 +61,20 @@ def wrapped_fn(env): def create_expert_indices_with_redundancy( - num_layers: int, - num_logical_experts: int, - total_physical_experts: int, - redundancy_config: list[int], # redundancy for each logical expert + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert ) -> torch.Tensor: """ Create expert indices with redundancy. - + Args: num_layers: number of layers num_logical_experts: number of logical experts total_physical_experts: total number of physical experts redundancy_config: redundancy for each logical expert - + Returns: indices: Shape (num_layers, total_physical_experts) """ @@ -106,11 +107,11 @@ def create_expert_weights( ) -> list[list[torch.Tensor]]: """ Create fake expert weights tensor for testing. - + Use `arange` to generate predictable weights values, based on logical expert ID. All replicas of the same logical expert should have the same weights. 
- + Args: physical_to_logical_mapping: Shape (num_layers, num_local_experts) mapping[layer, physical_pos] = logical_expert_id @@ -120,27 +121,27 @@ def create_expert_weights( for layer in range(num_layers): layer_weights = [] for weight_idx, hidden_size in enumerate(hidden_sizes): - weight_tensor = torch.zeros(num_local_experts, - hidden_size, - device=device, - dtype=torch.float32) + weight_tensor = torch.zeros( + num_local_experts, hidden_size, device=device, dtype=torch.float32 + ) for local_expert in range(num_local_experts): # Get the logical expert ID for this physical expert global_pos = rank * num_local_experts + local_expert logical_expert_id = physical_to_logical_mapping[ - layer, global_pos].item() + layer, global_pos + ].item() # Generate weights based on logical expert ID # (so that all replicas of the same logical expert have the # same weights) - base_value = (logical_expert_id * 1000 + layer * 100 + - weight_idx * 10) - weight_tensor[local_expert] = torch.arange(base_value, - base_value + - hidden_size, - device=device, - dtype=torch.float32) + base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10 + weight_tensor[local_expert] = torch.arange( + base_value, + base_value + hidden_size, + device=device, + dtype=torch.float32, + ) layer_weights.append(weight_tensor) expert_weights.append(layer_weights) @@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle( # Check if the weights are correct actual_weights = weight_tensor[local_expert] - expected_base = (expected_logical_expert * 1000 + layer * 100 + - weight_idx * 10) - expected_weights = torch.arange(expected_base, - expected_base + hidden_size, - device=actual_weights.device, - dtype=actual_weights.dtype) + expected_base = ( + expected_logical_expert * 1000 + layer * 100 + weight_idx * 10 + ) + expected_weights = torch.arange( + expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype, + ) torch.testing.assert_close( actual_weights, @@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle( msg=f"Layer {layer}, weight {weight_idx}," f"local expert {local_expert}: " f"weights do not match. 
" - f"Expected logical expert {expected_logical_expert}") + f"Expected logical expert {expected_logical_expert}", + ) def verify_redundant_experts_have_same_weights( @@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights( total_physical_experts, hidden_size, device=expert_weights[layer][weight_idx].device, - dtype=expert_weights[layer][weight_idx].dtype) + dtype=expert_weights[layer][weight_idx].dtype, + ) # Use all_gather to collect expert weights from current node # expert_weights[layer][weight_idx] shape: # [num_local_experts, hidden_size] local_weights = expert_weights[layer][ - weight_idx] # [num_local_experts, hidden_size] + weight_idx + ] # [num_local_experts, hidden_size] # Split tensor along dim 0 into a list for all_gather - gathered_weights_list = torch.chunk(gathered_weights, - world_size, - dim=0) + gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0) torch.distributed.all_gather( # Output list: each element corresponds to one rank's weights list(gathered_weights_list), - local_weights # Input: current rank's local weights + local_weights, # Input: current rank's local weights ) all_weights.append(gathered_weights) @@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights( msg=f"Layer {layer}, weight {weight_idx}," f"logical expert {logical_expert_id}: " f"Physical expert {physical_pos} has different weights" - f"than expected") + f"than expected", + ) @pytest.mark.parametrize( @@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights( # 4 GPU, 8 experts per GPU # 16 logical experts, 32 physical experts, 16 redundant experts (4, 8, 8, 16), - ]) -def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, - num_local_experts, - num_logical_experts): + ], +) +def test_rearrange_expert_weights_with_redundancy( + world_size, num_layers, num_local_experts, num_logical_experts +): """Test the functionality of rearranging expert weights with redundancy.""" if torch.cuda.device_count() < world_size: @@ -304,8 +311,8 @@ def worker_fn(): # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -316,8 +323,9 @@ def worker_fn(): hidden_sizes = [32, 64] # Two different weight matrices # Create old expert indices (with redundancy) - redundancy_config = create_redundancy_config(num_logical_experts, - total_physical_experts) + redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( num_layers, @@ -328,7 +336,8 @@ def worker_fn(): # Create new expert indices (with redundancy) new_redundancy_config = create_redundancy_config( - num_logical_experts, total_physical_experts) + num_logical_experts, total_physical_experts + ) new_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, @@ -337,9 +346,9 @@ def worker_fn(): ) # Create expert weights - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Execute weight rearrangement rearrange_expert_weights_inplace( @@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size): 
@worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -401,12 +410,12 @@ def worker_fn(): # Same indices - no change indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - redundancy_config) + num_layers, num_logical_experts, total_physical_experts, redundancy_config + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices + ) # Save original weights original_weights = [] @@ -422,7 +431,8 @@ def worker_fn(): indices, # Same indices expert_weights, ep_group, - is_profile=False) + is_profile=False, + ) # Verify that the weights have not changed for layer in range(num_layers): @@ -430,8 +440,8 @@ def worker_fn(): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg=f"Layer {layer}, weight {weight_idx} should remain " - f"unchanged") + msg=f"Layer {layer}, weight {weight_idx} should remain unchanged", + ) distributed_run(worker_fn, world_size) @@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size): @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -460,21 +470,23 @@ def worker_fn(): hidden_sizes = [32] # Create different index distributions - old_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) - new_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) + old_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) + new_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - old_redundancy) + num_layers, num_logical_experts, total_physical_experts, old_redundancy + ) new_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - new_redundancy) + num_layers, num_logical_experts, total_physical_experts, new_redundancy + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Save original weights original_weights = [] @@ -490,7 +502,7 @@ def worker_fn(): new_indices, expert_weights, ep_group, - is_profile=True # Profile mode + is_profile=True, # Profile mode ) # In profile mode, the weights should remain unchanged @@ -499,6 +511,7 @@ def worker_fn(): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg="In profile mode, the weights should remain unchanged") + msg="In profile mode, the weights should remain unchanged", + ) distributed_run(worker_fn, world_size) diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index 8be9ee0a1889..f06f6771a4a0 
100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -6,24 +6,29 @@ import msgspec import pytest -from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, - NullEventPublisher) +from vllm.distributed.kv_events import ( + EventBatch, + EventPublisherFactory, + NullEventPublisher, +) DP_RANK = 0 class EventSample( - msgspec.Struct, - tag=True, # type: ignore - array_like=True # type: ignore + msgspec.Struct, + tag=True, # type: ignore + array_like=True, # type: ignore ): """Test event for publisher testing""" + id: int value: str class SampleBatch(EventBatch): """Test event batch for publisher testing""" + events: list[EventSample] @@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber): seq, received = result assert seq == 0, "Sequence number mismatch" - assert received.ts == pytest.approx(test_batch.ts, - abs=0.1), ("Timestamp mismatch") - assert len(received.events) == len( - test_batch.events), ("Number of events mismatch") + assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch" + assert len(received.events) == len(test_batch.events), "Number of events mismatch" for i, event in enumerate(received.events): assert event.id == i, "Event id mismatch" @@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber): assert len(replayed) > 0, "No replayed messages received" seqs = [seq for seq, _ in replayed] assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" - assert seqs == list(range(min(seqs), - max(seqs) + - 1)), ("Replayed messages not consecutive") + assert seqs == list(range(min(seqs), max(seqs) + 1)), ( + "Replayed messages not consecutive" + ) def test_buffer_limit(publisher, subscriber, publisher_config): @@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config): pub = EventPublisherFactory.create(publisher_config, DP_RANK) from .conftest import MockSubscriber + sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") @@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config): foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] assert all(msg is not None for msg in foo_received), ( - "Subscriber with matching topic should receive messages") + "Subscriber with matching topic should receive messages" + ) bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] assert all(msg is None for msg in bar_received), ( - "Subscriber with non-matching topic should receive no messages") + "Subscriber with non-matching topic should receive no messages" + ) finally: pub.shutdown() sub_foo.close() @@ -178,8 +184,7 @@ def publish_events(): publisher_thread.join() - assert len(received) >= num_batches * 0.9, ( - "We should have received most messages") + assert len(received) >= num_batches * 0.9, "We should have received most messages" seqs = [seq for seq, _ in received] assert sorted(seqs) == seqs, "Sequence numbers should be in order" @@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config): # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558 expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port expected_endpoint_1 = base_endpoint.replace( - ":5557", ":5558") # rank 1 gets port + 1 + ":5557", ":5558" + ) # rank 1 gets port + 1 else: # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1 expected_endpoint_0 = base_endpoint # rank 0 gets base expected_endpoint_1 = base_endpoint + 
"_dp1" # rank 1 gets _dp1 from .conftest import MockSubscriber + sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic) sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic) @@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config): # Verify DP rank tagging assert received_0.data_parallel_rank == 0, ( - f"Expected DP rank 0, got {received_0.data_parallel_rank}") + f"Expected DP rank 0, got {received_0.data_parallel_rank}" + ) assert received_1.data_parallel_rank == 1, ( - f"Expected DP rank 1, got {received_1.data_parallel_rank}") + f"Expected DP rank 1, got {received_1.data_parallel_rank}" + ) # Verify event content is correct - assert len( - received_0.events) == 2, "Wrong number of events from rank 0" - assert len( - received_1.events) == 3, "Wrong number of events from rank 1" + assert len(received_0.events) == 2, "Wrong number of events from rank 0" + assert len(received_1.events) == 3, "Wrong number of events from rank 1" finally: pub_0.shutdown() diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index f641bf160414..a010e5b0f709 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -46,28 +46,24 @@ def detailed( ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True), + ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True + ), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False + ), ], distributed_backends=["mp", "ray"], task=task, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) @staticmethod @@ -82,16 +78,16 @@ def fast( ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False), ], distributed_backends=["mp"], task=task, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) def iter_params(self, model_name: str): @@ -99,8 +95,7 @@ def iter_params(self, model_name: str): for parallel_setup in self.parallel_setups: for distributed_backend in self.distributed_backends: - yield (model_name, parallel_setup, distributed_backend, - self.task, opts) + yield (model_name, parallel_setup, distributed_backend, self.task, opts) # NOTE: You can adjust tp_base locally to fit the model in GPU diff --git 
a/tests/distributed/test_multi_node_assignment.py b/tests/distributed/test_multi_node_assignment.py index ef17a51fff0e..8d818edbb3bd 100644 --- a/tests/distributed/test_multi_node_assignment.py +++ b/tests/distributed/test_multi_node_assignment.py @@ -24,14 +24,13 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.skipif(not VLLM_MULTI_NODE, - reason="Need at least 2 nodes to run the test.") +@pytest.mark.skipif( + not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test." +) def test_multi_node_assignment() -> None: - # NOTE: important to keep this class definition here # to let ray use cloudpickle to serialize it. class Actor: - def get_ip(self): return get_ip() @@ -41,8 +40,7 @@ def get_ip(self): current_ip = get_ip() workers = [] - for bundle_id, bundle in enumerate( - config.placement_group.bundle_specs): + for bundle_id, bundle in enumerate(config.placement_group.bundle_specs): if not bundle.get("GPU", 0): continue scheduling_strategy = PlacementGroupSchedulingStrategy( diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py index e3c36ef5ef37..b48c025aa1a2 100644 --- a/tests/distributed/test_node_count.py +++ b/tests/distributed/test_node_count.py @@ -32,12 +32,15 @@ # Expected node count based on environment variable) expected = int(os.environ.get("NUM_NODES", "1")) - assert test_result == expected, \ - f"Expected {expected} nodes, got {test_result}" + assert test_result == expected, f"Expected {expected} nodes, got {test_result}" if pg == dist.group.WORLD: - print(f"Node count test passed! Got {test_result} nodes " - f"when using torch distributed!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using torch distributed!" + ) else: - print(f"Node count test passed! Got {test_result} nodes " - f"when using StatelessProcessGroup!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using StatelessProcessGroup!" + ) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 926a33c949eb..5bc71a0bbb8c 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -7,6 +7,7 @@ all workers in a node other than the head node, which can cause the test to fail. """ + import json import os from dataclasses import dataclass @@ -35,7 +36,7 @@ def use_v0_only(monkeypatch): weights. Once we enable V1 by default for PP, we can remove this. 
""" - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") class ParallelSetup(NamedTuple): @@ -68,7 +69,8 @@ def __post_init__(self): raise ValueError( f"Length mismatch: distributed_backends " f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") + f"vllm_major_versions ({len(self.vllm_major_versions)})" + ) @staticmethod def detailed( @@ -81,32 +83,43 @@ def detailed( ): return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup( + tp_size=tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=False, + ), + ParallelSetup( + tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=False, + chunked_prefill=True, + ), + ParallelSetup( + tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=True, + chunked_prefill=False, + ), + ParallelSetup( + tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=True, + ), + ParallelSetup( + tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False, + ), ], distributed_backends=["mp", "mp", "ray", "ray"], vllm_major_versions=["0", "1", "0", "1"], task=task, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -120,26 +133,36 @@ def fast( ): return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup( + tp_size=tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False, + ), ], distributed_backends=["mp"], vllm_major_versions=["0"], task=task, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.task, opts) + for backend, vllm_major_version in zip( + self.distributed_backends, self.vllm_major_versions + ): + yield ( + model_id, + parallel_setup, + backend, + vllm_major_version, + self.task, + opts, + ) # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU @@ -317,8 +340,10 @@ def _compare_tp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -348,8 +373,7 @@ def _compare_tp( specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill testing_ray_compiled_graph 
= False - if distributed_backend == "ray" and (vllm_major_version == "1" - or specific_case): + if distributed_backend == "ray" and (vllm_major_version == "1" or specific_case): # For V1, test Ray Compiled Graph for all the tests # For V0, test Ray Compiled Graph for a subset of the tests pp_env = { @@ -398,12 +422,7 @@ def _compare_tp( ] try: - compare_two_settings(model_id, - pp_args, - tp_args, - pp_env, - tp_env, - method=method) + compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method) except Exception: if testing_ray_compiled_graph and vllm_major_version == "0": # Ray Compiled Graph tests are flaky for V0, @@ -414,11 +433,19 @@ def _compare_tp( @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "task", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "task", + "test_options", + ), [ - params for model_id, settings in TEXT_GENERATION_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -431,23 +458,33 @@ def test_tp_language_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - task, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "task", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "task", + "test_options", + ), [ - params for model_id, settings in EMBEDDING_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in EMBEDDING_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -460,23 +497,33 @@ def test_tp_language_embedding( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - task, - test_options, - num_gpus_available, - method="encode", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="encode", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "task", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "task", + "test_options", + ), [ - params for model_id, settings in MULTIMODAL_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in MULTIMODAL_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -489,12 +536,14 @@ def test_tp_multimodal_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - task, - test_options, - 
num_gpus_available, - method="generate", - is_multimodal=True) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=True, + ) diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py index 69ceedd345a8..4df6f43970d7 100644 --- a/tests/distributed/test_pipeline_partition.py +++ b/tests/distributed/test_pipeline_partition.py @@ -9,7 +9,6 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: def _verify(partition_str, num_layers, pp_size, goldens): @@ -57,7 +56,8 @@ def _verify(partition_str, num_layers, pp_size, goldens): (5, 3, 0, (0, 2)), (5, 3, 1, (2, 4)), (5, 3, 2, (4, 5)), - ]) + ], +) def test_uneven_auto_partition( num_hidden_layers: int, pp_size: int, diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index a027a9e37dd6..518b1bf76fd7 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -12,13 +12,19 @@ from typing_extensions import LiteralString -@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ - (2, "JackFram/llama-160m"), -]) -@pytest.mark.parametrize("ATTN_BACKEND", [ - "FLASH_ATTN", - "FLASHINFER", -]) +@pytest.mark.parametrize( + "PP_SIZE, MODEL_NAME", + [ + (2, "JackFram/llama-160m"), + ], +) +@pytest.mark.parametrize( + "ATTN_BACKEND", + [ + "FLASH_ATTN", + "FLASHINFER", + ], +) @create_new_process_for_each_test() def test_pp_cudagraph( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index abfad9ebfe7d..4bab709fb589 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,13 +9,15 @@ import torch import torch.distributed -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_world_group, graph_capture, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_world_group, + graph_capture, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -24,13 +26,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -47,7 +49,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = 
torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -58,17 +60,18 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) tensor = pynccl_comm.all_reduce(tensor) torch.cuda.synchronize() assert torch.all(tensor == pynccl_comm.world_size).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl(): distributed_run(worker_fn, 2) @@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 1], backend="gloo"), - torch.distributed.new_group(ranks=[2, 3], backend="gloo") + torch.distributed.new_group(ranks=[2, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) @@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce(): # this tests pynccl for multiple tp groups, in a standalone way # i.e. call `pynccl_comm.all_reduce` directly @@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce_with_vllm(): # this tests pynccl for multiple tp groups, together with vllm # i.e. 
call `tensor_model_parallel_all_reduce` @@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') + a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}") torch.cuda.synchronize() with torch.cuda.graph(graph): a_out = pynccl_comm.all_reduce(a) @@ -148,84 +154,90 @@ def worker_fn_with_cudagraph(): @worker_fn_wrapper def all_gather_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - result = torch.zeros(num_elems * world_size, - dtype=torch.float32, - device=device) - - expected = torch.cat([ - torch.arange(num_elems, dtype=torch.float32) + r * num_elems - for r in range(world_size) - ]).to(device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device) + + expected = torch.cat( + [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gather(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_all_gather(): distributed_run(all_gather_worker_fn, 2) @worker_fn_wrapper def all_gatherv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sizes[rank] - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) - expected = torch.cat([ - torch.arange(sizes[r], dtype=torch.float32) + r * 100 - for r in range(world_size) - ]).to(device) + expected = torch.cat( + [ + torch.arange(sizes[r], dtype=torch.float32) + r * 100 + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gatherv(result, tensor, sizes=sizes) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_all_gatherv(): distributed_run(all_gatherv_worker_fn, 2) @worker_fn_wrapper def reduce_scatter_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - assert (num_elems % world_size == 0) - result = torch.zeros(num_elems // world_size, - dtype=torch.float32, - device=device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + assert num_elems % world_size == 0 + result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk scattered_size = num_elems // world_size @@ -233,34 +245,37 @@ def reduce_scatter_worker_fn(): torch.arange(num_elems, dtype=torch.float32) + r * num_elems for r in range(world_size) ] - expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] - for tensor in all_tensors).to(device) + expected = sum( + tensor[rank * scattered_size : (rank + 1) * scattered_size] + for tensor in all_tensors + ).to(device) pynccl_comm.reduce_scatter(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_reduce_scatter(): distributed_run(reduce_scatter_worker_fn, 2) @worker_fn_wrapper def reduce_scatterv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sum(sizes) - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk @@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn(): torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_reduce_scatterv(): distributed_run(reduce_scatterv_worker_fn, 2) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) @worker_fn_wrapper def send_recv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) if pynccl_comm.rank == 0: - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) else: - tensor = torch.empty(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) if pynccl_comm.rank == 0: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() assert torch.all(tensor == 1).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_send_recv(): distributed_run(send_recv_worker_fn, 2) @@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 2], backend="gloo"), - torch.distributed.new_group(ranks=[1, 3], backend="gloo") + torch.distributed.new_group(ranks=[1, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) if torch.distributed.get_rank() == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) elif torch.distributed.get_rank() == 1: - tensor = 2 * torch.ones( - 16, 1024, 1024, dtype=torch.float32, device=device) + tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) else: - tensor = torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=device) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device) if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() if torch.distributed.get_rank() in [0, 2]: assert torch.all(tensor == 1).cpu().item() @@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." 
+) def test_pynccl_broadcast(): distributed_run(broadcast_worker_fn, 4) @@ -366,19 +376,17 @@ def test_pynccl_broadcast(): def broadcast_worker_fn(): # Test broadcast for every root rank. # Essentially this is an all-gather operation. - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) recv_tensors = [ - torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=pynccl_comm.device) + torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) for i in range(pynccl_comm.world_size) ] - recv_tensors[pynccl_comm.rank] = torch.ones( - 16, 1024, 1024, dtype=torch.float32, - device=pynccl_comm.device) * pynccl_comm.rank + recv_tensors[pynccl_comm.rank] = ( + torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) + * pynccl_comm.rank + ) for i in range(pynccl_comm.world_size): pynccl_comm.broadcast(recv_tensors[i], src=i) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index a4added29144..579bc55dbba7 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -8,21 +8,24 @@ import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_group, + get_tp_group, + graph_capture, +) from vllm.platforms import current_platform -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) torch.manual_seed(42) random.seed(44) # Size over 8MB is sufficient for custom quick allreduce. 
-test_sizes = [ - random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8) -] +test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)] for i, v in enumerate(test_sizes): test_sizes[i] -= v % 8 @@ -39,8 +42,7 @@ def graph_quickreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tensor_model_parallel_group().device_group @@ -65,18 +67,15 @@ def graph_quickreduce( for sz in test_sizes: for dtype in [torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: - inp1 = torch.randint(1, - 23, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(-23, - 1, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for _ in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) dist.all_reduce(inp1, group=group) @@ -100,39 +99,42 @@ def eager_quickreduce( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # Size over 8MB is sufficient for custom quick allreduce. 
sz = 16 * 1024 * 1024 fa = get_tp_group().device_communicator.qr_comm - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.float16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.bfloat16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test quick allreduce for rocm") +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="only test quick allreduce for rocm" +) @pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) -def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size, test_target, - quant_mode): +def test_custom_quick_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pipeline_parallel_size, + test_target, + quant_mode, +): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 94ad8f4f1213..baf75fd48c63 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -22,15 +22,13 @@ dist.broadcast_object_list(recv, src=0) ip, port = recv - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: test_result = all(in_the_same_node_as(pg, source_rank=0)) expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" - assert test_result == expected, \ - f"Expected {expected}, got {test_result}" + assert test_result == expected, f"Expected {expected}, got {test_result}" if pg == dist.group.WORLD: print("Same node test passed! when using torch distributed!") else: diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index b2f6a8ab9dd3..67f960f294d4 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -7,6 +7,7 @@ all workers in a node other than the head node, which can cause the test to fail. 
""" + import json import os from dataclasses import dataclass @@ -56,7 +57,8 @@ def __post_init__(self): raise ValueError( f"Length mismatch: distributed_backends " f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") + f"vllm_major_versions ({len(self.vllm_major_versions)})" + ) @staticmethod def detailed( @@ -72,18 +74,22 @@ def detailed( for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -100,18 +106,22 @@ def fast( for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -126,28 +136,39 @@ def fp8_quant( parallel_setups = [] for fusion_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - enable_fusion=fusion_val, - eager_mode=True, - chunked_prefill=False)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_base, + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.task, opts) + for backend, vllm_major_version in zip( + self.distributed_backends, self.vllm_major_versions + ): + yield ( + model_id, + parallel_setup, + backend, + vllm_major_version, + self.task, + opts, + ) def _compare_sp( @@ -199,8 +220,10 @@ def _compare_sp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if 
multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -229,14 +252,14 @@ def _compare_sp( common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) compilation_config = { - 'level': 3, - 'custom_ops': ["+rms_norm"], - 'compile_sizes': [4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_sequence_parallelism': True, - 'enable_fusion': enable_fusion, - 'enable_noop': True, + "level": 3, + "custom_ops": ["+rms_norm"], + "compile_sizes": [4, 8], + "splitting_ops": [], + "pass_config": { + "enable_sequence_parallelism": True, + "enable_fusion": enable_fusion, + "enable_noop": True, }, } @@ -266,12 +289,9 @@ def _compare_sp( ] try: - compare_two_settings(model_id, - tp_sp_args, - tp_args, - tp_sp_env, - tp_env, - method=method) + compare_two_settings( + model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method + ) except Exception: testing_ray_compiled_graph = tp_sp_env is not None if testing_ray_compiled_graph and vllm_major_version == "0": @@ -292,15 +312,22 @@ def _compare_sp( # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "task", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "task", + "test_options", + ), [ - params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + params + for model_id, settings in SP_TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_id) if model_id in SP_TEST_MODELS ], @@ -315,12 +342,14 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, ): - _compare_sp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - task, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_sp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index e1357b4a34e9..cdea1bfe8f28 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -26,13 +26,13 @@ def distributed_run(fn, world_size): processes = [] for i in range(number_of_processes): env = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -57,25 +57,23 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - rank = dist.get_rank() if rank == 0: port = get_open_port() - ip = '127.0.0.1' + ip = "127.0.0.1" dist.broadcast_object_list([ip, port], src=0) else: recv = [None, None] dist.broadcast_object_list(recv, src=0) ip, port = recv # type: ignore - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, 
dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: - writer_rank = 2 broadcaster = MessageQueue.create_from_process_group( - pg, 40 * 1024, 2, writer_rank) + pg, 40 * 1024, 2, writer_rank + ) if rank == writer_rank: seed = random.randint(0, 1000) dist.broadcast_object_list([seed], writer_rank) diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 9f2c3eaec359..f415409d7b37 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -24,13 +24,15 @@ # set different `gpu_memory_utilization` and `swap_space` for different ranks, # to test if all ranks agree on the same kv cache configuration. -llm = LLM(model="facebook/opt-125m", - tensor_parallel_size=2, - pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), - distributed_executor_backend="external_launcher", - gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), - seed=0) +llm = LLM( + model="facebook/opt-125m", + tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0, +) outputs = llm.generate(prompts, sampling_params) @@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj): assert container[0] == obj -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) # make sure we can access the model parameters from the calling process # of the `LLM` instance. -params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. 
- model.parameters()) +params = list( + llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() +) test_consistent_across_ranks(len(params)) # all ranks should have the same outputs @@ -65,5 +66,4 @@ def test_consistent_across_ranks(obj): generated_text = output.outputs[0].text test_consistent_across_ranks(prompt) test_consistent_across_ranks(generated_text) - print(f"Rank {torch_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 0287ad94e388..2a6936fcd4c2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -10,21 +10,22 @@ import vllm.envs as envs from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import (cuda_device_count_stateless, get_open_port, - update_environment_variables) +from vllm.utils import ( + cuda_device_count_stateless, + get_open_port, + update_environment_variables, +) from ..utils import multi_gpu_test @ray.remote class _CUDADeviceCountStatelessTestActor: - def get_count(self): return cuda_device_count_stateless() def set_cuda_visible_devices(self, cuda_visible_devices: str): - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) def get_cuda_visible_devices(self): return envs.CUDA_VISIBLE_DEVICES @@ -34,10 +35,9 @@ def test_cuda_device_count_stateless(): """Test that cuda_device_count_stateless changes return value if CUDA_VISIBLE_DEVICES is changed.""" actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore - num_gpus=2).remote() - assert len( - sorted(ray.get( - actor.get_cuda_visible_devices.remote()).split(","))) == 2 + num_gpus=2 + ).remote() + assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2 assert ray.get(actor.get_count.remote()) == 2 ray.get(actor.set_cuda_visible_devices.remote("0")) assert ray.get(actor.get_count.remote()) == 1 @@ -46,15 +46,13 @@ def test_cuda_device_count_stateless(): def cpu_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) data = torch.tensor([rank]) data = pg1.broadcast_obj(data, src=2) assert data.item() == 2 @@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2): torch.cuda.set_device(rank) - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) pynccl1 = PyNcclCommunicator(pg1, device=rank) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) pynccl2 = PyNcclCommunicator(pg2, device=rank) data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) @@ -96,10 +92,9 @@ 
def gpu_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank == 2: pg1.broadcast_obj("secret", src=2) else: @@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) data = pg1.all_gather_obj(rank) assert data == list(range(WORLD_SIZE)) pg1.barrier() @@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2): @pytest.mark.skip(reason="This test is flaky and prone to hang.") @multi_gpu_test(num_gpus=4) @pytest.mark.parametrize( - "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) + "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker] +) def test_stateless_process_group(worker): port1 = get_open_port() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -129,12 +124,14 @@ def test_stateless_process_group(worker): port2 = get_open_port() WORLD_SIZE = 4 from multiprocessing import get_context + ctx = get_context("fork") processes = [] for i in range(WORLD_SIZE): rank = i processes.append( - ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))) + ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)) + ) for p in processes: p.start() for p in processes: diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 8b99d9d6e21f..26866a95fb55 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -4,22 +4,24 @@ Run `pytest tests/encoder_decoder/test_e2e_correctness.py`. """ + from typing import Optional import pytest from transformers import AutoModelForSeq2SeqLM -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) +from vllm.attention.selector import ( + _Backend, + _cached_get_attn_backend, + global_force_attn_backend_context_manager, +) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close -LIST_ENC_DEC_SUPPORTED_BACKENDS = [ - _Backend.XFORMERS, _Backend.FLASH_ATTN, None -] +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN, None] @pytest.fixture(scope="function", autouse=True) @@ -28,7 +30,7 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") def vllm_to_hf_output( @@ -61,7 +63,7 @@ def clear_cache(): @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.skipif( current_platform.is_cpu(), - reason="CPU backend is not currently supported with encoder/decoder models" + reason="CPU backend is not currently supported with encoder/decoder models", ) def test_encoder_decoder_e2e( hf_runner, @@ -75,19 +77,18 @@ def test_encoder_decoder_e2e( enforce_eager: bool, attn_backend: _Backend, ) -> None: - ''' + """ End-to-End (E2E) test for the encoder-decoder framework. 
This test evaluates the encoder-decoder functionality using the BART model. We compare the outputs of the Hugging Face and vLLM implementations to ensure that both implementations produce consistent and correct results. - ''' + """ with global_force_attn_backend_context_manager(attn_backend): if attn_backend == _Backend.FLASH_ATTN: # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - test_case_prompts = example_encoder_decoder_prompts[ - decoder_prompt_type] + dtype = "bfloat16" + test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] # Configuration settings for HF baseline hf_kwargs = { @@ -98,25 +99,22 @@ def test_encoder_decoder_e2e( "length_penalty": 1.0, "early_stopping": False, "no_repeat_ngram_size": None, - "min_length": 0 + "min_length": 0, } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + ) + with vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) + test_case_prompts, max_tokens, num_logprobs + ) - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE - else 0) + hf_skip_tokens = 1 if decoder_prompt_type == DecoderPromptType.NONE else 0 check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py index 375b248ebeda..a6a8b33e19d3 100644 --- a/tests/engine/conftest.py +++ b/tests/engine/conftest.py @@ -9,4 +9,4 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. 
""" - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 5a91758414a5..90028efb2d81 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -10,22 +10,30 @@ import pytest from vllm.config import CompilationConfig, config -from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, - get_type, get_type_hints, is_not_builtin, - is_type, literal_to_kwargs, optional_type, - parse_type) +from vllm.engine.arg_utils import ( + EngineArgs, + contains_type, + get_kwargs, + get_type, + get_type_hints, + is_not_builtin, + is_type, + literal_to_kwargs, + optional_type, + parse_type, +) from vllm.utils import FlexibleArgumentParser -@pytest.mark.parametrize(("type", "value", "expected"), [ - (int, "42", 42), - (float, "3.14", 3.14), - (str, "Hello World!", "Hello World!"), - (json.loads, '{"foo":1,"bar":2}', { - "foo": 1, - "bar": 2 - }), -]) +@pytest.mark.parametrize( + ("type", "value", "expected"), + [ + (int, "42", 42), + (float, "3.14", 3.14), + (str, "Hello World!", "Hello World!"), + (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}), + ], +) def test_parse_type(type, value, expected): parse_type_func = parse_type(type) assert parse_type_func(value) == expected @@ -37,43 +45,52 @@ def test_optional_type(): assert optional_type_func("42") == 42 -@pytest.mark.parametrize(("type_hint", "type", "expected"), [ - (int, int, True), - (int, float, False), - (list[int], list, True), - (list[int], tuple, False), - (Literal[0, 1], Literal, True), -]) +@pytest.mark.parametrize( + ("type_hint", "type", "expected"), + [ + (int, int, True), + (int, float, False), + (list[int], list, True), + (list[int], tuple, False), + (Literal[0, 1], Literal, True), + ], +) def test_is_type(type_hint, type, expected): assert is_type(type_hint, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({float, int}, int, True), - ({int, tuple[int]}, int, True), - ({int, tuple[int]}, float, False), - ({str, Literal["x", "y"]}, Literal, True), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({float, int}, int, True), + ({int, tuple[int]}, int, True), + ({int, tuple[int]}, float, False), + ({str, Literal["x", "y"]}, Literal, True), + ], +) def test_contains_type(type_hints, type, expected): assert contains_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({int, float}, int, int), - ({int, float}, str, None), - ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({int, float}, int, int), + ({int, float}, str, None), + ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), + ], +) def test_get_type(type_hints, type, expected): assert get_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "expected"), [ - ({Literal[1, 2]}, { - "type": int, - "choices": [1, 2] - }), - ({Literal[1, "a"]}, Exception), -]) +@pytest.mark.parametrize( + ("type_hints", "expected"), + [ + ({Literal[1, 2]}, {"type": int, "choices": [1, 2]}), + ({Literal[1, "a"]}, Exception), + ], +) def test_literal_to_kwargs(type_hints, expected): context = nullcontext() if expected is Exception: @@ -144,22 +161,27 @@ class DummyConfig: """Different config with from_cli method""" -@pytest.mark.parametrize(("type_hint", "expected"), [ - (int, False), - (DummyConfig, True), -]) +@pytest.mark.parametrize( + 
("type_hint", "expected"), + [ + (int, False), + (DummyConfig, True), + ], +) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected @pytest.mark.parametrize( - ("type_hint", "expected"), [ + ("type_hint", "expected"), + [ (Annotated[int, "annotation"], {int}), (Optional[int], {int, type(None)}), (Annotated[Optional[int], "annotation"], {int, type(None)}), (Optional[Annotated[int, "annotation"]], {int, type(None)}), ], - ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"]) + ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"], +) def test_get_type_hints(type_hint, expected): assert get_type_hints(type_hint) == expected @@ -199,24 +221,16 @@ def test_get_kwargs(): ("arg", "expected"), [ (None, dict()), - ('{"video": {"num_frames": 123} }', { - "video": { - "num_frames": 123 - } - }), + ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}), ( '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa { - "video": { - "num_frames": 123, - "fps": 1.0, - "foo": "bar" - }, - "image": { - "foo": "bar" - } - }), - ]) + "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, + "image": {"foo": "bar"}, + }, + ), + ], +) def test_media_io_kwargs_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: @@ -251,24 +265,32 @@ def test_compilation_config(): assert args.compilation_config.level == 3 # set to string form of a dict - args = parser.parse_args([ - "-O", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": false}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and not args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": false}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor + ) # set to string form of a dict - args = parser.parse_args([ - "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": true}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "--compilation-config=" + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor + ) def test_prefix_cache_default(): @@ -276,8 +298,7 @@ def test_prefix_cache_default(): args = parser.parse_args([]) engine_args = EngineArgs.from_cli_args(args=args) - assert (not engine_args.enable_prefix_caching - ), "prefix caching defaults to off." + assert not engine_args.enable_prefix_caching, "prefix caching defaults to off." # with flag to turn it on. args = parser.parse_args(["--enable-prefix-caching"]) diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py index ac5a1f957dfe..8a05351e879f 100644 --- a/tests/engine/test_computed_prefix_blocks.py +++ b/tests/engine/test_computed_prefix_blocks.py @@ -18,15 +18,17 @@ def test_computed_prefix_blocks(model: str, block_size: int): prompt = ( "You are a helpful assistant. 
How do I build a car from cardboard and " "paper clips? Is there an easy to follow video tutorial available " - "online for free?") + "online for free?" + ) prompt2 = ( " Please recommend to me some resources where I can learn not only to " "handle technical difficulties of building a car, but also " - "decoration.") + "decoration." + ) - engine_args = EngineArgs(model=model, - block_size=block_size, - enable_prefix_caching=True) + engine_args = EngineArgs( + model=model, block_size=block_size, enable_prefix_caching=True + ) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams() diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py index 15c7a97b50e1..bc6994c5f041 100644 --- a/tests/engine/test_executor.py +++ b/tests/engine/test_executor.py @@ -14,17 +14,17 @@ from vllm.sampling_params import SamplingParams -class Mock: - ... +class Mock: ... class CustomUniExecutor(UniProcExecutor): - - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ) -> list[Any]: # Drop marker to show that this was ran with open(".marker", "w"): ... @@ -37,12 +37,10 @@ def collective_rpc(self, @pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) def test_custom_executor_type_checking(model): with pytest.raises(ValueError): - engine_args = EngineArgs(model=model, - distributed_executor_backend=Mock) + engine_args = EngineArgs(model=model, distributed_executor_backend=Mock) LLMEngine.from_engine_args(engine_args) with pytest.raises(ValueError): - engine_args = AsyncEngineArgs(model=model, - distributed_executor_backend=Mock) + engine_args = AsyncEngineArgs(model=model, distributed_executor_backend=Mock) AsyncLLMEngine.from_engine_args(engine_args) diff --git a/tests/engine/test_multi_step_output_processor.py b/tests/engine/test_multi_step_output_processor.py index 458f4deb743a..9935eeea8b29 100644 --- a/tests/engine/test_multi_step_output_processor.py +++ b/tests/engine/test_multi_step_output_processor.py @@ -11,8 +11,12 @@ from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput, SequenceStatus) +from vllm.sequence import ( + CompletionSequenceGroupOutput, + Logprob, + SequenceOutput, + SequenceStatus, +) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -44,9 +48,9 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): seq_group = create_seq_group( seq_prompt_len=1024, seq_output_lens=[seq_output_len], - sampling_params=SamplingParams(max_tokens=seq_output_len + - num_new_tokens, - ignore_eos=True), + sampling_params=SamplingParams( + max_tokens=seq_output_len + num_new_tokens, ignore_eos=True + ), ) seq = seq_group.get_seqs()[0] @@ -64,12 +68,13 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): ) ], prompt_logprobs=None, - ) for output_token in new_token_ids + ) + for output_token in new_token_ids ] - assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids + assert seq.get_token_ids()[-len(new_token_ids) :] != new_token_ids output_processor.process_outputs(seq_group, outputs) - assert 
seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + assert seq.get_token_ids()[-len(new_token_ids) :] == new_token_ids @pytest.mark.parametrize("seq_prompt_len", [1024]) @@ -77,8 +82,9 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): @pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) @pytest.mark.parametrize("max_tokens", [128 + 3]) @pytest.mark.skip_global_cleanup -def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, max_tokens: int): +def test_respects_max_tokens( + num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, max_tokens: int +): """Verify tokens after max_tokens are dropped and not appended to the sequence. """ @@ -98,7 +104,9 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_group = create_seq_group( seq_prompt_len=seq_prompt_len, seq_output_lens=[seq_output_len], - sampling_params=SamplingParams(max_tokens=max_tokens, ), + sampling_params=SamplingParams( + max_tokens=max_tokens, + ), ) seq = seq_group.get_seqs()[0] @@ -116,7 +124,8 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, ) ], prompt_logprobs=None, - ) for output_token in new_token_ids + ) + for output_token in new_token_ids ] assert seq.get_len() == seq_prompt_len + seq_output_len @@ -126,9 +135,11 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, assert seq.get_len() == seq_prompt_len + max_tokens # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens + expected_appended_tokens = new_token_ids[: max_tokens - seq_output_len] + assert ( + seq.get_token_ids()[-len(expected_appended_tokens) :] + == expected_appended_tokens + ) @pytest.mark.parametrize("seq_prompt_len", [1024]) @@ -136,8 +147,9 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, @pytest.mark.parametrize("num_new_tokens", [12]) @pytest.mark.parametrize("seed", list(range(6))) @pytest.mark.skip_global_cleanup -def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, seed: int): +def test_respects_eos_token_id( + num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int +): """Verify the eos token id is included in the sequence, but subsequent tokens are dropped (not appended to sequence). """ @@ -162,7 +174,8 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_lens=[seq_output_len], sampling_params=SamplingParams( # Ensure enough space. - max_tokens=seq_output_len + num_new_tokens, ), + max_tokens=seq_output_len + num_new_tokens, + ), ) seq = seq_group.get_seqs()[0] @@ -183,7 +196,8 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, ) ], prompt_logprobs=None, - ) for output_token in new_token_ids + ) + for output_token in new_token_ids ] assert seq.get_len() == seq_prompt_len + seq_output_len @@ -193,9 +207,11 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) # Expect the correct tokens were appended. 
- expected_appended_tokens = new_token_ids[:eos_index + 1] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens + expected_appended_tokens = new_token_ids[: eos_index + 1] + assert ( + seq.get_token_ids()[-len(expected_appended_tokens) :] + == expected_appended_tokens + ) @pytest.mark.parametrize("seq_prompt_len", [1024]) @@ -203,8 +219,9 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, @pytest.mark.parametrize("num_new_tokens", [12]) @pytest.mark.parametrize("seed", list(range(6))) @pytest.mark.skip_global_cleanup -def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, seed: int): +def test_ignores_eos_token_id( + num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int +): """When sampling parameters dictate that we should ignore the eos token id, ensure all token ids are appended even if the eos token id is emitted. """ @@ -252,7 +269,8 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, ) ], prompt_logprobs=None, - ) for output_token in new_token_ids + ) + for output_token in new_token_ids ] assert seq.get_len() == seq_prompt_len + seq_output_len @@ -262,10 +280,13 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - - seq_output_len] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens + expected_appended_tokens = new_token_ids[ + : seq_output_len + num_new_tokens - seq_output_len + ] + assert ( + seq.get_token_ids()[-len(expected_appended_tokens) :] + == expected_appended_tokens + ) def mock_tokenizer(eos_token_id=1000): diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py index b5381b61a020..3ca19da99ccd 100644 --- a/tests/engine/test_multiproc_workers.py +++ b/tests/engine/test_multiproc_workers.py @@ -10,8 +10,11 @@ import pytest from vllm.config import VllmConfig -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) +from vllm.executor.multiproc_worker_utils import ( + ProcessWorkerWrapper, + ResultHandler, + WorkerMonitor, +) from vllm.worker.worker_base import WorkerWrapperBase @@ -32,8 +35,8 @@ def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]: result_handler = ResultHandler() vllm_config = VllmConfig() workers = [ - ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config, - rank) for rank in range(8) + ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config, rank) + for rank in range(8) ] worker_monitor = WorkerMonitor(workers, result_handler) @@ -53,8 +56,7 @@ def test_local_workers() -> None: def execute_workers(worker_input: str) -> None: worker_outputs = [ - worker.execute_method("worker_method", worker_input) - for worker in workers + worker.execute_method("worker_method", worker_input) for worker in workers ] for rank, output in enumerate(worker_outputs): @@ -152,8 +154,7 @@ async def execute_workers(worker_input: str) -> None: # Test error case exception = ValueError("fake error") try: - _result = await workers[0].execute_method_async( - "worker_method", exception) + _result = await workers[0].execute_method_async("worker_method", exception) pytest.fail("task should have failed") except Exception as e: assert isinstance(e, ValueError) @@ 
-172,8 +173,7 @@ async def execute_workers(worker_input: str) -> None: # Further attempts to submit tasks should fail try: - _result = await workers[0].execute_method_async( - "worker_method", "test") + _result = await workers[0].execute_method_async("worker_method", "test") pytest.fail("task should fail once workers have been shut down") except Exception as e: assert isinstance(e, ChildProcessError) diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py index 42e88e84770a..b15bf1f4dcb0 100644 --- a/tests/engine/test_options.py +++ b/tests/engine/test_options.py @@ -23,8 +23,9 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) + outputs = llm.generate( + {"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params + ) assert len(outputs) > 0 completions = outputs[0].outputs assert len(completions) > 0 @@ -34,8 +35,7 @@ def test_skip_tokenizer_initialization(model: str): @pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) -def test_enable_prompt_embeds(hf_runner, model: str, - enable_prompt_embeds: bool): +def test_enable_prompt_embeds(hf_runner, model: str, enable_prompt_embeds: bool): prompt = "abc" with hf_runner(model) as hf_model: @@ -45,8 +45,11 @@ def test_enable_prompt_embeds(hf_runner, model: str, embed_layer = hf_model.model.get_input_embeddings() prompt_embeds = embed_layer(token_ids).squeeze(0) - ctx = (nullcontext() if enable_prompt_embeds else pytest.raises( - ValueError, match="set `--enable-prompt-embeds`")) + ctx = ( + nullcontext() + if enable_prompt_embeds + else pytest.raises(ValueError, match="set `--enable-prompt-embeds`") + ) llm = LLM( model=model, diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py index 9c62761d78af..f63c0cc596e4 100644 --- a/tests/engine/test_short_mm_context.py +++ b/tests/engine/test_short_mm_context.py @@ -5,12 +5,12 @@ from ..conftest import IMAGE_ASSETS -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: \nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: \nWhat is the season?\nASSISTANT:", -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "USER: \nWhat's the content of the image?\nASSISTANT:", + "cherry_blossom": "USER: \nWhat is the season?\nASSISTANT:", + } +) models = ["llava-hf/llava-1.5-7b-hf"] @@ -19,8 +19,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model): images = [asset.pil_image for asset in image_assets] - with pytest.raises(ValueError, - match="longer than the maximum model length"): + with pytest.raises(ValueError, match="longer than the maximum model length"): vllm_model = vllm_runner( model, max_model_len=128, # LLaVA has a feature size of 576 @@ -28,6 +27,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model): ) with vllm_model: - vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], - max_tokens=1, - images=[images[0]]) + vllm_model.generate_greedy( + [HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]] + ) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index a7c533ec2419..bea264cc8fb5 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -26,8 +26,10 @@ def sample_token_ids(): @pytest.fixture def sample_regex(): - return 
(r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) @pytest.fixture @@ -35,40 +37,27 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + "minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -80,65 +69,53 @@ def sample_complex_json_schema(): "score": { "type": "integer", "minimum": 0, - "maximum": 100 # Numeric range + "maximum": 100, # Numeric range }, "grade": { "type": "string", - "pattern": "^[A-D]$" # Regex pattern + "pattern": "^[A-D]$", # Regex pattern }, "email": { "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "tags": { "type": "array", "items": { "type": "string", - "pattern": - "^[a-z]{1,10}$" # Combining length and pattern restrictions - } - } + "pattern": "^[a-z]{1,10}$", # Combining length and pattern restrictions + }, + }, }, - "required": ["score", "grade", "email", "tags"] + "required": ["score", "grade", "email", "tags"], } @pytest.fixture def sample_definition_json_schema(): return { - '$defs': { - 'Step': { - 'properties': { - 'explanation': { - 'title': 'Explanation', - 'type': 'string' - }, - 'output': { - 'title': 'Output', - 'type': 'string' - } + "$defs": { + "Step": { + "properties": { + "explanation": {"title": "Explanation", "type": "string"}, + "output": {"title": "Output", "type": "string"}, }, - 'required': ['explanation', 'output'], - 'title': 'Step', - 'type': 'object' + "required": ["explanation", "output"], + "title": "Step", + "type": "object", } }, - 'properties': { - 'steps': { - 'items': { - '$ref': '#/$defs/Step' - }, - 'title': 'Steps', - 'type': 'array' + "properties": { + "steps": { + "items": {"$ref": "#/$defs/Step"}, + "title": "Steps", + "type": "array", }, - 'final_answer': { - 'title': 'Final Answer', - 'type': 'string' - } + "final_answer": {"title": "Final Answer", "type": "string"}, }, - 'required': ['steps', 'final_answer'], - 'title': 'MathReasoning', - 'type': 'object' + "required": ["steps", "final_answer"], + "title": "MathReasoning", + "type": "object", } @@ -149,55 +126,61 @@ def sample_enum_json_schema(): "properties": { "status": { "type": "string", - "enum": ["active", "inactive", - "pending"] # Literal values using enum + "enum": ["active", "inactive", "pending"], # Literal values using enum }, "priority": { "type": "string", - "enum": ["low", "medium", "high", "critical"] + "enum": ["low", "medium", "high", "critical"], }, "category": { "type": "object", "properties": { "type": { "type": "string", - "enum": ["bug", "feature", "improvement"] + "enum": ["bug", "feature", "improvement"], }, "severity": { "type": "integer", - "enum": 
[1, 2, 3, 4, - 5] # Enum can also contain numbers - } + "enum": [1, 2, 3, 4, 5], # Enum can also contain numbers + }, }, - "required": ["type", "severity"] + "required": ["type", "severity"], }, "flags": { "type": "array", "items": { "type": "string", - "enum": ["urgent", "blocked", "needs_review", "approved"] - } - } + "enum": ["urgent", "blocked", "needs_review", "approved"], + }, + }, }, - "required": ["status", "priority", "category", "flags"] + "required": ["status", "priority", "category", "flags"], } @pytest.fixture def sample_guided_choice(): return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" + "Python", + "Java", + "JavaScript", + "C++", + "C#", + "PHP", + "TypeScript", + "Ruby", + "Swift", + "Kotlin", ] @pytest.fixture def sample_sql_statements(): - return (""" + return """ start: select_statement select_statement: "SELECT" column "from" table "where" condition column: "col_1" | "col_2" table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""" diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index 30a666d4c39c..3b416eef4d73 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -45,20 +45,23 @@ def run_test(model_name, more_args=None): measured_value = results["results"][TASK][FILTER] assert model_name in EXPECTED_VALUES, ( - f"Cannot find the expected value for the model {model_name=}") + f"Cannot find the expected value for the model {model_name=}" + ) expected_value = EXPECTED_VALUES[model_name] - assert (measured_value - RTOL < expected_value - and measured_value + RTOL > expected_value - ), f"Expected: {expected_value} | Measured: {measured_value}" + assert ( + measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" # TODO: [AlexM] Fix it with new CI/CD tests -TPU_TP_TEST_STR = "" #"tensor_parallel_size=4" +TPU_TP_TEST_STR = "" # "tensor_parallel_size=4" -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 is currently only supported on CUDA and TPU") +@pytest.mark.skipif( + not current_platform.is_cuda() and not current_platform.is_tpu(), + reason="V1 is currently only supported on CUDA and TPU", +) @pytest.mark.parametrize("model", MODEL_NAMES) def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): """Run with the V1 Engine.""" diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 97cf3b5ce8fc..ce641a45fa32 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -14,9 +14,7 @@ def text_llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - seed=0) + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) @@ -29,14 +27,8 @@ def text_llm(): def test_chat(text_llm): prompt1 = "Explain the concept of entropy." 
messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -47,25 +39,13 @@ def test_multi_chat(text_llm): prompt2 = "Explain what among us is." conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt2}, ] messages = [conversation1, conversation2] @@ -96,25 +76,20 @@ def vision_llm(): cleanup_dist_env_and_memory() -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +@pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(vision_llm, image_urls: list[str]): - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] outputs = vision_llm.chat(messages) assert len(outputs) >= 0 @@ -125,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm): Check we get a single BOS token for llama chat. """ messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello!" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello!"}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -169,14 +138,8 @@ def thinking_llm(): @pytest.mark.parametrize("enable_thinking", [True, False]) def test_chat_extra_kwargs(thinking_llm, enable_thinking): messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "What is 1+1?" 
- }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What is 1+1?"}, ] outputs = thinking_llm.chat( diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 3a13f8c979f2..937aa5c13246 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -23,9 +23,11 @@ def echo_rank(self): return self.rank monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=tp_size, - distributed_executor_backend=backend) + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=tp_size, + distributed_executor_backend=backend, + ) assert llm.collective_rpc(echo_rank) == list(range(tp_size)) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index b930f05bebd0..163da19d1465 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -41,12 +41,14 @@ def v1(run_with_both_engines): def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) @@ -56,8 +58,9 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_match(o1: list[PoolingRequestOutput], - o2: list[PoolingRequestOutput]): +def assert_outputs_match( + o1: list[PoolingRequestOutput], o2: list[PoolingRequestOutput] +): check_embeddings_close( embeddings_0_lst=[o.outputs.data for o in o1], embeddings_1_lst=[o.outputs.data for o in o2], @@ -68,17 +71,18 @@ def assert_outputs_match(o1: list[PoolingRequestOutput], @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): +@pytest.mark.parametrize("prompt_token_ids", TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, prompt_token_ids): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) + v1_output = llm.encode( + prompt_token_ids=prompt_token_ids, pooling_params=pooling_params + ) - v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) + v2_output = llm.encode( + {"prompt_token_ids": prompt_token_ids}, pooling_params=pooling_params + ) assert_outputs_match(v1_output, v2_output) @@ -87,13 +91,12 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) + v1_output = llm.encode( + prompt_token_ids=TOKEN_IDS, pooling_params=pooling_params + ) v2_output = llm.encode( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], + [{"prompt_token_ids": p} for p in TOKEN_IDS], pooling_params=pooling_params, ) assert_outputs_match(v1_output, v2_output) diff --git a/tests/entrypoints/llm/test_generate.py 
b/tests/entrypoints/llm/test_generate.py index 707891f6bdd8..bde2d7c31c09 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -35,11 +35,13 @@ def v1(run_with_both_engines): def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True, + ) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) @@ -54,17 +56,18 @@ def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): +@pytest.mark.parametrize("prompt_token_ids", TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, prompt_token_ids): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) + v1_output = llm.generate( + prompt_token_ids=prompt_token_ids, sampling_params=sampling_params + ) - v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, - sampling_params=sampling_params) + v2_output = llm.generate( + {"prompt_token_ids": prompt_token_ids}, sampling_params=sampling_params + ) assert_outputs_equal(v1_output, v2_output) @@ -73,13 +76,12 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, - sampling_params=sampling_params) + v1_output = llm.generate( + prompt_token_ids=TOKEN_IDS, sampling_params=sampling_params + ) v2_output = llm.generate( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], + [{"prompt_token_ids": p} for p in TOKEN_IDS], sampling_params=sampling_params, ) assert_outputs_equal(v1_output, v2_output) @@ -124,7 +126,8 @@ def test_max_model_len(): outputs = llm.generate(PROMPTS, sampling_params) for output in outputs: num_total_tokens = len(output.prompt_token_ids) + len( - output.outputs[0].token_ids) + output.outputs[0].token_ids + ) # Total tokens must not exceed max_model_len. # It can be less if generation finishes due to other reasons (e.g., EOS) # before reaching the absolute model length limit. 
diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index b7d53e31fd71..a99d6ea01f76 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -4,6 +4,7 @@ import weakref import pytest + # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -26,6 +27,7 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @@ -33,20 +35,21 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[False, True]) def llm(request, monkeypatch_module): - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') + monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - tensor_parallel_size=1, - max_model_len=8192, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - max_num_seqs=128, - enforce_eager=True) + llm = LLM( + model=MODEL_NAME, + tensor_parallel_size=1, + max_model_len=8192, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + max_num_seqs=128, + enforce_eager=True, + ) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py index 533da9e6d6ea..896091533ad2 100644 --- a/tests/entrypoints/llm/test_gpu_utilization.py +++ b/tests/entrypoints/llm/test_gpu_utilization.py @@ -16,9 +16,8 @@ def test_gpu_memory_utilization(): # makes sure gpu_memory_utilization is per-instance limit, # not a global limit llms = [ - LLM(model="facebook/opt-125m", - gpu_memory_utilization=0.3, - enforce_eager=True) for i in range(3) + LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True) + for i in range(3) ] for llm in llms: outputs = llm.generate(prompts, sampling_params) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 55578341cb2e..1797a7beaa5e 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -26,7 +26,7 @@ ("guidance", True), ] -ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS) +ALL_DECODING_BACKENDS = [("outlines", False)] + GRAMMAR_DECODING_BACKENDS @pytest.fixture(scope="module") @@ -42,23 +42,27 @@ def llm(): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_regex( + sample_regex, llm, guided_decoding_backend: str, disable_any_whitespace: bool +): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams( regex=sample_regex, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) + disable_any_whitespace=disable_any_whitespace, + ), + ) - outputs = llm.generate(prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts=[f"Give an example IPv4 address with this regex: {sample_regex}"] * 2, + 
sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -73,24 +77,30 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_json_completion(sample_json_schema, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_json_completion( + sample_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool +): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams( json=sample_json_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) - outputs = llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + disable_any_whitespace=disable_any_whitespace, + ), + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] + * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -107,24 +117,33 @@ def test_guided_json_completion(sample_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_complex_json_completion(sample_complex_json_schema, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_complex_json_completion( + sample_complex_json_schema, + llm, + guided_decoding_backend: str, + disable_any_whitespace: bool, +): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams( json=sample_complex_json_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) - outputs = llm.generate(prompts=[ - f"Give an example JSON for an assignment grade " - f"that fits this schema: {sample_complex_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + disable_any_whitespace=disable_any_whitespace, + ), + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an assignment grade " + f"that fits this schema: {sample_complex_json_schema}" + ] + * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -137,29 +156,37 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, - schema=sample_complex_json_schema) + jsonschema.validate(instance=output_json, schema=sample_complex_json_schema) @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_definition_json_completion(sample_definition_json_schema, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_definition_json_completion( + 
sample_definition_json_schema, + llm, + guided_decoding_backend: str, + disable_any_whitespace: bool, +): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams( json=sample_definition_json_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) - outputs = llm.generate(prompts=[ - f"Give an example JSON for solving 8x + 7 = -23 " - f"that fits this schema: {sample_definition_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + disable_any_whitespace=disable_any_whitespace, + ), + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for solving 8x + 7 = -23 " + f"that fits this schema: {sample_definition_json_schema}" + ] + * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -172,29 +199,37 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, - schema=sample_definition_json_schema) + jsonschema.validate(instance=output_json, schema=sample_definition_json_schema) @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_enum_json_completion(sample_enum_json_schema, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_enum_json_completion( + sample_enum_json_schema, + llm, + guided_decoding_backend: str, + disable_any_whitespace: bool, +): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams( json=sample_enum_json_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) - outputs = llm.generate(prompts=[ - "Create a bug report JSON that fits this schema: " - f"{sample_enum_json_schema}. Make it for a high priority critical bug." - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + disable_any_whitespace=disable_any_whitespace, + ), + ) + outputs = llm.generate( + prompts=[ + "Create a bug report JSON that fits this schema: " + f"{sample_enum_json_schema}. Make it for a high priority critical bug." 
+ ] + * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -207,37 +242,41 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, - schema=sample_enum_json_schema) + jsonschema.validate(instance=output_json, schema=sample_enum_json_schema) # Additional assertions to verify enum values assert output_json["status"] in ["active", "inactive", "pending"] assert output_json["priority"] in ["low", "medium", "high", "critical"] - assert output_json["category"]["type"] in [ - "bug", "feature", "improvement" - ] + assert output_json["category"]["type"] in ["bug", "feature", "improvement"] assert output_json["category"]["severity"] in [1, 2, 3, 4, 5] for flag in output_json["flags"]: assert flag in ["urgent", "blocked", "needs_review", "approved"] @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_choice_completion(sample_guided_choice, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_choice_completion( + sample_guided_choice, + llm, + guided_decoding_backend: str, + disable_any_whitespace: bool, +): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams( choice=sample_guided_choice, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) + disable_any_whitespace=disable_any_whitespace, + ), + ) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -252,11 +291,15 @@ def test_guided_choice_completion(sample_guided_choice, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GRAMMAR_DECODING_BACKENDS) -def test_guided_grammar(sample_sql_statements, llm, - guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", GRAMMAR_DECODING_BACKENDS +) +def test_guided_grammar( + sample_sql_statements, + llm, + guided_decoding_backend: str, + disable_any_whitespace: bool, +): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, @@ -264,10 +307,14 @@ def test_guided_grammar(sample_sql_statements, llm, guided_decoding=GuidedDecodingParams( grammar=sample_sql_statements, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) + disable_any_whitespace=disable_any_whitespace, + ), + ) outputs = llm.generate( - prompts=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), + prompts=( + "Generate a sql state that select col_1 from " + "table_1 where it is equals to 1" + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -282,12 +329,12 @@ def test_guided_grammar(sample_sql_statements, llm, assert generated_text is not None # use Lark to parse the output, and make sure it's a valid parse tree from lark import Lark + parser = Lark(sample_sql_statements) parser.parse(generated_text) # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 
1".replace( - " ", "") + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") assert generated_text.strip() == ground_truth @@ -299,10 +346,12 @@ def test_guided_options_request_deprecation_warning(sample_regex, llm): sampling_params = SamplingParams(temperature=0.8, top_p=0.95) with pytest.warns(DeprecationWarning, match="guided_options_request"): - llm.generate(prompts="This should fail", - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + llm.generate( + prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex), + ) @pytest.mark.skip_global_cleanup @@ -310,13 +359,16 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - guided_decoding=GuidedDecodingParams(regex=sample_regex)) + guided_decoding=GuidedDecodingParams(regex=sample_regex), + ) with pytest.raises(ValueError, match="Cannot set both"): - llm.generate(prompts="This should fail", - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + llm.generate( + prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex), + ) @pytest.mark.skip_global_cleanup @@ -327,31 +379,35 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): "properties": { "example": { "type": "string", - "minLength": 5 # unsupported by xgrammar + "minLength": 5, # unsupported by xgrammar } - } + }, } - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - json=unsupported_json, - backend="xgrammar", - disable_fallback=True)) + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + json=unsupported_json, backend="xgrammar", disable_fallback=True + ), + ) with pytest.raises( - ValueError, - match="xgrammar does not support advanced JSON schema features " - "like string length, item limits, or property bounds."): - llm.generate(prompts="This should fail", - sampling_params=sampling_params, - use_tqdm=True) + ValueError, + match="xgrammar does not support advanced JSON schema features " + "like string length, item limits, or property bounds.", + ): + llm.generate( + prompts="This should fail", sampling_params=sampling_params, use_tqdm=True + ) @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GRAMMAR_DECODING_BACKENDS) -def test_guided_json_object(llm, guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", GRAMMAR_DECODING_BACKENDS +) +def test_guided_json_object( + llm, guided_decoding_backend: str, disable_any_whitespace: bool +): sampling_params = SamplingParams( temperature=1.0, max_tokens=100, @@ -359,13 +415,18 @@ def test_guided_json_object(llm, guided_decoding_backend: str, guided_decoding=GuidedDecodingParams( json_object=True, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) + disable_any_whitespace=disable_any_whitespace, + ), + ) outputs = llm.generate( - prompts=("Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old."), + prompts=( + "Generate a JSON object with curly braces for a person with " + "name and age fields for 
John Smith who is 31 years old." + ), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -401,10 +462,12 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_json_completion_with_enum( + llm, guided_decoding_backend: str, disable_any_whitespace: bool +): json_schema = CarDescription.model_json_schema() sampling_params = SamplingParams( temperature=1.0, @@ -412,12 +475,15 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, guided_decoding=GuidedDecodingParams( json=json_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace)) + disable_any_whitespace=disable_any_whitespace, + ), + ) outputs = llm.generate( prompts="Generate a JSON with the brand, model and car_type of" "the most iconic car from the 90's", sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -433,27 +499,18 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - ALL_DECODING_BACKENDS) -def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, - disable_any_whitespace: bool): +@pytest.mark.parametrize( + "guided_decoding_backend,disable_any_whitespace", ALL_DECODING_BACKENDS +) +def test_guided_number_range_json_completion( + llm, guided_decoding_backend: str, disable_any_whitespace: bool +): sample_output_schema = { "type": "object", "properties": { - "age": { - "type": "integer", - "minimum": 18, - "maximum": 99 - }, - "score": { - "type": "number", - "minimum": 0.0, - "maximum": 100.0 - }, - "zipcode": { - "type": "string", - "pattern": r"^\d{5}(-\d{4})?$" - }, + "age": {"type": "integer", "minimum": 18, "maximum": 99}, + "score": {"type": "number", "minimum": 0.0, "maximum": 100.0}, + "zipcode": {"type": "string", "pattern": r"^\d{5}(-\d{4})?$"}, }, "required": ["age", "score", "zipcode"], } @@ -463,12 +520,11 @@ def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, guided_decoding=GuidedDecodingParams( json=sample_output_schema, backend=guided_decoding_backend, - disable_any_whitespace=disable_any_whitespace), + disable_any_whitespace=disable_any_whitespace, + ), ) outputs = llm.generate( - prompts=[ - "Create a JSON object for a user with age, score, and zipcode." 
- ] * 2, + prompts=["Create a JSON object for a user with age, score, and zipcode."] * 2, sampling_params=sampling_params, use_tqdm=True, ) @@ -487,43 +543,38 @@ def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, jsonschema.validate(instance=output_json, schema=sample_output_schema) assert 18 <= output_json["age"] <= 99 assert 0.0 <= output_json["score"] <= 100.0 - assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"]) - is not None) + assert re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"]) is not None @pytest.mark.skip_global_cleanup def test_guidance_no_additional_properties(llm): schema = { - 'type': 'object', - 'properties': { - 'a1': { - 'type': 'string' - }, - 'a2': { - 'type': 'string' - }, - 'a3': { - 'type': 'string' - } + "type": "object", + "properties": { + "a1": {"type": "string"}, + "a2": {"type": "string"}, + "a3": {"type": "string"}, }, - 'required': ['a1', 'a2', 'a3'], + "required": ["a1", "a2", "a3"], } prompt = ( "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" - "<|im_end|>\n<|im_start|>assistant\n") + "<|im_end|>\n<|im_start|>assistant\n" + ) def generate_with_backend(backend, disable_additional_properties): guided_params = GuidedDecodingParams( json=schema, backend=backend, disable_any_whitespace=True, - disable_additional_properties=disable_additional_properties) - sampling_params = SamplingParams(temperature=0, - max_tokens=256, - guided_decoding=guided_params) + disable_additional_properties=disable_additional_properties, + ) + sampling_params = SamplingParams( + temperature=0, max_tokens=256, guided_decoding=guided_params + ) outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) assert outputs is not None diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index 61b6b4fbf8e3..c8679852a54f 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -5,10 +5,10 @@ from contextlib import nullcontext import pytest -from vllm_test_utils import BlameResult, blame from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory +from vllm_test_utils import BlameResult, blame @pytest.fixture(scope="function", autouse=True) @@ -16,7 +16,7 @@ def use_v0_only(monkeypatch): """ V1 only supports xgrammar so this is irrelevant. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") def run_normal_opt125m(): @@ -29,9 +29,7 @@ def run_normal_opt125m(): sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM without guided decoding as a baseline. - llm = LLM(model="facebook/opt-125m", - enforce_eager=True, - gpu_memory_utilization=0.3) + llm = LLM(model="facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.3) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt @@ -53,9 +51,9 @@ def run_normal(): sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM without guided decoding as a baseline. 
- llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - gpu_memory_utilization=0.3) + llm = LLM( + model="distilbert/distilgpt2", enforce_eager=True, gpu_memory_utilization=0.3 + ) outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt @@ -69,18 +67,19 @@ def run_normal(): def run_lmfe(sample_regex): # Create an LLM with guided decoding enabled. - llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - guided_decoding_backend="lm-format-enforcer", - gpu_memory_utilization=0.3) + llm = LLM( + model="distilbert/distilgpt2", + enforce_eager=True, + guided_decoding_backend="lm-format-enforcer", + gpu_memory_utilization=0.3, + ) sampling_params = SamplingParams(temperature=0.8, top_p=0.95) outputs = llm.generate( - prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, + prompts=[f"Give an example IPv4 address with this regex: {sample_regex}"] * 2, sampling_params=sampling_params, use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + guided_options_request=dict(guided_regex=sample_regex), + ) for output in outputs: prompt = output.prompt @@ -89,8 +88,7 @@ def run_lmfe(sample_regex): def test_lazy_outlines(sample_regex): - """If users don't use guided decoding, outlines should not be imported. - """ + """If users don't use guided decoding, outlines should not be imported.""" # make sure outlines is not imported module_name = "outlines" # In CI, we only check finally if the module is imported. @@ -99,8 +97,7 @@ def test_lazy_outlines(sample_regex): # and help find the root cause. # We don't run it in CI by default because it is slow. use_blame = False - context = blame( - lambda: module_name in sys.modules) if use_blame else nullcontext() + context = blame(lambda: module_name in sys.modules) if use_blame else nullcontext() with context as result: run_normal() run_lmfe(sample_regex) @@ -109,4 +106,5 @@ def test_lazy_outlines(sample_regex): print(f"the first import location is:\n{result.trace_stack}") assert module_name not in sys.modules, ( f"Module {module_name} is imported. To see the first" - f" import location, run the test with `use_blame=True`.") + f" import location, run the test with `use_blame=True`." 
+ ) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 1b7be15d5d69..a63e9dd5240f 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -16,12 +16,12 @@ def v1(run_with_both_engines): def test_empty_prompt(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='decoder prompt cannot be empty'): + with pytest.raises(ValueError, match="decoder prompt cannot be empty"): llm.generate([""]) @pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='out of vocabulary'): + with pytest.raises(ValueError, match="out of vocabulary"): llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a606eeab5887..ab5e20c0df74 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" + import importlib import sys @@ -88,12 +89,11 @@ def disable_connect(*args, **kwargs): def _re_import_modules(): - hf_hub_module_names = [ - k for k in sys.modules if k.startswith("huggingface_hub") - ] + hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")] transformers_module_names = [ - k for k in sys.modules if k.startswith("transformers") - and not k.startswith("transformers_modules") + k + for k in sys.modules + if k.startswith("transformers") and not k.startswith("transformers_modules") ] reload_exception = None diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index 41b70f80e3b8..fe2f88577294 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -27,7 +27,7 @@ [], # Default ["--enable-chunked-prefill"], # Chunked ["--num-scheduler-steps", "8"], # MS - ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream + ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"], # MS+Stream ] MAX_WAIT_SECONDS = None @@ -47,14 +47,15 @@ def run_test(more_args): print(f"Running with: {args}") with RemoteOpenAIServer( - MODEL_NAME, args, - max_wait_seconds=MAX_WAIT_SECONDS) as remote_server: + MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS + ) as remote_server: url = f"{remote_server.url_for('v1')}/completions" model_args = ( f"model={MODEL_NAME}," f"base_url={url}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -63,14 +64,16 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 currently only supported on CUDA and TPU") +@pytest.mark.skipif( + not 
current_platform.is_cuda() and not current_platform.is_tpu(), + reason="V1 currently only supported on CUDA and TPU", +) def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): """Run with the V1 Engine.""" @@ -86,8 +89,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) -def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch, - more_args): +def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch, more_args): """Run with the V0 Engine.""" with monkeypatch.context() as m: diff --git a/tests/entrypoints/openai/correctness/test_mteb_embed.py b/tests/entrypoints/openai/correctness/test_mteb_embed.py index 12a86f9bdd59..48e8d0fce86a 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_embed.py +++ b/tests/entrypoints/openai/correctness/test_mteb_embed.py @@ -4,10 +4,12 @@ import pytest -from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, - MTEB_EMBED_TOL, - OpenAIClientMtebEncoder, - run_mteb_embed_task) +from tests.models.language.pooling.mteb_utils import ( + MTEB_EMBED_TASKS, + MTEB_EMBED_TOL, + OpenAIClientMtebEncoder, + run_mteb_embed_task, +) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" @@ -18,9 +20,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--task", "embed", "--enforce-eager", "--disable-uvicorn-access-log" - ] + args = ["--task", "embed", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/openai/correctness/test_mteb_score.py index 05e953de4a0f..569aabb8c53f 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/openai/correctness/test_mteb_score.py @@ -7,9 +7,15 @@ # yapf conflicts with isort for this block # yapf: disable from tests.models.language.pooling.mteb_utils import ( - MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, - RerankClientMtebEncoder, ScoreClientMtebEncoder, - mteb_test_rerank_models_hf, run_mteb_rerank) + MTEB_RERANK_LANGS, + MTEB_RERANK_TASKS, + MTEB_RERANK_TOL, + RerankClientMtebEncoder, + ScoreClientMtebEncoder, + mteb_test_rerank_models_hf, + run_mteb_rerank, +) + # yapf: enable from tests.utils import RemoteOpenAIServer @@ -20,9 +26,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--task", "score", "--enforce-eager", "--disable-uvicorn-access-log" - ] + args = ["--task", "score", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -39,8 +43,7 @@ def st_main_score(hf_runner): def test_mteb_score(server, st_main_score): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) @@ -52,8 +55,7 @@ def test_mteb_score(server, st_main_score): def test_mteb_rerank(server, st_main_score): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", 
vllm_main_score) print("SentenceTransformer main score: ", st_main_score) diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd35..17c5867372c4 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -7,6 +7,7 @@ This simulates real work usage of the API and makes sure that the frontend and AsyncLLMEngine are working correctly. """ + import asyncio import io import time @@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr): # NOTE there's no streaming in transcriptions, can't measure ttft latency = end_time - start_time num_output_tokens = len( - tokenizer(transcription.text, add_special_tokens=False).input_ids) + tokenizer(transcription.text, add_special_tokens=False).input_ids + ) return latency, num_output_tokens, transcription.text @@ -71,7 +73,8 @@ async def process_dataset(model, client, data, concurrent_request): for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(model, sem, client, (audio, sr), sample["text"]) + ) tasks.append(task) return await asyncio.gather(*tasks) @@ -95,34 +98,35 @@ def print_performance_metrics(results, total_time): def add_duration(sample): - y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] - sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000 return sample -def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): +def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs): ## Load and filter the dataset dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) - if 'duration_ms' not in dataset[0]: + if "duration_ms" not in dataset[0]: # compute duration to filter dataset = dataset.map(add_duration) # Whisper max supported duration - dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + dataset = dataset.filter(lambda example: example["duration_ms"] < 30000) return dataset -def run_evaluation(model: str, - client, - dataset, - max_concurrent_reqs: int, - n_examples: int = -1, - print_metrics: bool = True): +def run_evaluation( + model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True, +): if n_examples > 0: dataset = dataset.select(range(n_examples)) start = time.perf_counter() - results = asyncio.run( - process_dataset(model, client, dataset, max_concurrent_reqs)) + results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs)) end = time.perf_counter() total_time = end - start print(f"Total Test Time: {total_time:.4f} seconds") @@ -132,8 +136,7 @@ def run_evaluation(model: str, predictions = [res[2] for res in results] references = [res[3] for res in results] wer = load("wer") - wer_score = 100 * wer.compute(references=references, - predictions=predictions) + wer_score = 100 * wer.compute(references=references, predictions=predictions) print("WER:", wer_score) return wer_score @@ -142,26 +145,25 @@ def run_evaluation(model: str, @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) # Original dataset is 20GB+ in size, hence we use a 
pre-filtered slice. @pytest.mark.parametrize( - "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"] +) # NOTE: Expected WER measured with equivalent hf.transformers args: # whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. @pytest.mark.parametrize("expected_wer", [12.744980]) -def test_wer_correctness(model_name, - dataset_repo, - expected_wer, - n_examples=-1, - max_concurrent_request=None): +def test_wer_correctness( + model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None +): # TODO refactor to use `ASRDataset` - with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server: dataset = load_hf_dataset(dataset_repo) if not max_concurrent_request: # No max concurrency - max_concurrent_request = n_examples if n_examples > 0\ - else len(dataset) + max_concurrent_request = n_examples if n_examples > 0 else len(dataset) client = remote_server.get_async_client() - wer = run_evaluation(model_name, client, dataset, - max_concurrent_request, n_examples) + wer = run_evaluation( + model_name, client, dataset, max_concurrent_request, n_examples + ) if expected_wer: torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py index ab3c80905438..c81e2ec2190f 100644 --- a/tests/entrypoints/openai/test_async_tokenization.py +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -47,15 +47,11 @@ async def client(server): ids=["completion", "chat"], argnames=["create_func_gen", "content_body"], argvalues=[ - (lambda x: x.completions.create, { - "prompt": " ".join(['A'] * 10_000) - }), - (lambda x: x.chat.completions.create, { - "messages": [{ - "role": "user", - "content": " ".join(['A'] * 10_000) - }] - }), + (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}), + ( + lambda x: x.chat.completions.create, + {"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]}, + ), ], ) async def test_with_and_without_truncate( @@ -68,15 +64,15 @@ async def test_with_and_without_truncate( body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} num_requests = 10 - truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * - (num_requests - num_requests // 2)) + truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * ( + num_requests - num_requests // 2 + ) random.shuffle(truncate_prompt_tokens) - bodies = [{ - **body, "extra_body": { - 'truncate_prompt_tokens': t - } - } for t in truncate_prompt_tokens] + bodies = [ + {**body, "extra_body": {"truncate_prompt_tokens": t}} + for t in truncate_prompt_tokens + ] async def get_status_code(**kwargs): try: @@ -94,18 +90,12 @@ async def get_status_code(**kwargs): ids=["single completion", "multiple completions", "chat"], argnames=["create_func_gen", "content_body"], argvalues=[ - (lambda x: x.completions.create, { - "prompt": " ".join(['A'] * 300_000) - }), - (lambda x: x.completions.create, { - "prompt": [" ".join(['A'] * 300_000)] * 2 - }), - (lambda x: x.chat.completions.create, { - "messages": [{ - "role": "user", - "content": " ".join(['A'] * 300_000) - }] - }), + (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 300_000)}), + (lambda x: x.completions.create, {"prompt": [" ".join(["A"] * 300_000)] * 2}), + ( + lambda x: x.chat.completions.create, 
+ {"messages": [{"role": "user", "content": " ".join(["A"] * 300_000)}]}, + ), ], ) async def test_healthcheck_response_time( @@ -127,9 +117,7 @@ def get_response_time(url): return end_time - start_time no_load_response_time = get_response_time(server.url_for("health")) - tasks = [ - asyncio.create_task(create_func(**body)) for _ in range(num_requests) - ] + tasks = [asyncio.create_task(create_func(**body)) for _ in range(num_requests)] await asyncio.sleep(1) # give the tasks a chance to start running load_response_time = get_response_time(server.url_for("health")) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index d67c05ab3e8d..e9e73f88a7bb 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -54,24 +54,18 @@ def base64_encoded_audio() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_single_chat_session_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_single_chat_session_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -80,13 +74,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -108,56 +104,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI, - model_name: str, - audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": audio_url - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_error_on_invalid_audio_url_type( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": audio_url}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # audio_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_audio_base64encoded( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": - f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -166,13 +158,15 @@ async def test_single_chat_session_audio_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -196,25 +190,26 @@ async def test_single_chat_session_audio_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_input_audio( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -222,13 +217,15 @@ async def test_single_chat_session_input_audio( messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -250,24 +247,18 @@ async def test_single_chat_session_input_audio( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_chat_streaming_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -307,27 +298,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str, - base64_encoded_audio: dict[str, - str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_chat_streaming_input_audio( + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -367,26 +358,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]) -async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, - audio_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "audio_url", - "audio_url": { - "url": audio_url - } - } for audio_url in audio_urls), - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] + "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]] +) +async def test_multi_audio_input( + client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "audio_url", "audio_url": {"url": audio_url}} + for audio_url in audio_urls + ), + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] if len(audio_urls) > MAXIMUM_AUDIOS: with pytest.raises(openai.BadRequestError): # test multi-audio input diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index a55941976cd8..50ec87b4464f 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -16,9 +16,9 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def server_args(request: pytest.FixtureRequest) -> list[str]: - """ Provide extra arguments to the server via indirect parametrization + """Provide extra arguments to the server via indirect parametrization Usage: @@ -80,8 +80,10 @@ async def client(server): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer): @pytest.mark.parametrize( "server_args", [ - pytest.param(["--max-model-len", "10100"], - id="default-frontend-multiprocessing"), + pytest.param( + ["--max-model-len", "10100"], id="default-frontend-multiprocessing" + ), pytest.param( ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], - id="disable-frontend-multiprocessing") + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # Request about 2 million tokens for _ in range(200): task = 
asyncio.create_task( - client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10000, - extra_body={"min_tokens": 10000})) + client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000}, + ) + ) tasks.append(task) - done, pending = await asyncio.wait(tasks, - return_when=asyncio.ALL_COMPLETED) + done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) # Make sure all requests were sent to the server and timed out # (We don't want to hide other errors like 400s that would invalidate this @@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # If the server had not cancelled all the other requests, then it would not # be able to respond to this one within the timeout client = server.get_async_client(timeout=5) - response = await client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10) + response = await client.chat.completions.create( + messages=chat_input, model=MODEL_NAME, max_tokens=10 + ) assert len(response.choices) == 1 @pytest.mark.asyncio async def test_request_wrong_content_type(server: RemoteOpenAIServer): - chat_input = [{"role": "user", "content": "Write a long story"}] client = server.get_async_client() @@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer): messages=chat_input, model=MODEL_NAME, max_tokens=10000, - extra_headers={ - "Content-Type": "application/x-www-form-urlencoded" - }) + extra_headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) @pytest.mark.parametrize( "server_args", - [ - pytest.param(["--enable-server-load-tracking"], - id="enable-server-load-tracking") - ], + [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")], indirect=True, ) @pytest.mark.asyncio @@ -202,7 +205,8 @@ def make_long_completion_request(): # Start the completion request in a background thread. completion_future = asyncio.create_task( - asyncio.to_thread(make_long_completion_request)) + asyncio.to_thread(make_long_completion_request) + ) # Give a short delay to ensure the request has started. 
await asyncio.sleep(0.1) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index e7c3ffaa6a9f..a21be27e61ae 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -15,8 +15,10 @@ from openai import BadRequestError, OpenAI from ...utils import RemoteOpenAIServer -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 +from .test_completion import ( + zephyr_lora_added_tokens_files, # noqa: F401 + zephyr_lora_files, # noqa: F401 +) # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -25,6 +27,7 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @@ -32,13 +35,13 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[False, True]) def server( - request, - monkeypatch_module, - zephyr_lora_files, #noqa: F811 - zephyr_lora_added_tokens_files): # noqa: F811 - + request, + monkeypatch_module, + zephyr_lora_files, # noqa: F811 + zephyr_lora_added_tokens_files, +): # noqa: F811 use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') + monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") args = [ # use half precision for speed and memory savings in CI environment @@ -67,8 +70,9 @@ def server( @pytest.fixture def is_v1_server(server): import os - assert os.environ['VLLM_USE_V1'] in ['0', '1'] - return os.environ['VLLM_USE_V1'] == '1' + + assert os.environ["VLLM_USE_V1"] in ["0", "1"] + return os.environ["VLLM_USE_V1"] == "1" @pytest_asyncio.fixture @@ -84,20 +88,18 @@ async def client(server): [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], ) async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_completion_tokens=5, temperature=0.0, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.logprobs is None @@ -110,13 +112,10 @@ async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -124,7 +123,8 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=0) + top_logprobs=0, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -138,13 +138,10 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -152,7 +149,8 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -165,41 +163,39 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] +async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=21, - stream=True) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=21, + stream=True, + ) async for chunk in stream: ... 
with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=30, - stream=False) + await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=30, + stream=False, + ) # the server should still work afterwards chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=10, - stream=False) + model=model_name, messages=messages, max_completion_tokens=10, stream=False + ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 @@ -209,27 +205,20 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, "model_name, prompt_logprobs", [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], ) -async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +async def test_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, } if prompt_logprobs is not None: @@ -252,29 +241,21 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str): +async def test_more_than_one_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name, - "extra_body": { - "prompt_logprobs": 1 - } + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, + "extra_body": {"prompt_logprobs": 1}, } completion_1 = await client.chat.completions.create(**params) @@ -291,15 +272,11 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_single_chat_session(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] +async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -307,14 +284,16 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=37, total_tokens=47) + completion_tokens=10, prompt_tokens=37, total_tokens=47 + ) message = choice.message assert message.content is not None and len(message.content) >= 10 @@ -339,13 +318,10 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, [MODEL_NAME, "zephyr-lora"], ) async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -387,15 +363,13 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], ) -async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" 
- }] +async def test_chat_completion_stream_options( + client: openai.AsyncOpenAI, model_name: str +): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] # Test stream=True, stream_options={"include_usage": False} stream = await client.chat.completions.create( @@ -404,23 +378,21 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=True, - stream_options={"include_usage": False}) + stream_options={"include_usage": False}, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": False}} - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": - True, - "continuous_usage_stats": - False - }) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + stream=True, + stream_options={"include_usage": True, "continuous_usage_stats": False}, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: @@ -432,8 +404,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options={"include_usage": None} @@ -444,7 +416,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": None}) + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options={"include_usage": True} with pytest.raises(BadRequestError): @@ -454,7 +427,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": True}) + stream_options={"include_usage": True}, + ) # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": True} @@ -473,92 +447,86 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, last_completion_tokens = 0 async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 - assert last_completion_tokens == 0 or \ - chunk.usage.completion_tokens > last_completion_tokens or \ - ( - not chunk.choices and - chunk.usage.completion_tokens == last_completion_tokens - ) - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert ( + last_completion_tokens == 0 + or chunk.usage.completion_tokens > last_completion_tokens + or ( + not chunk.choices + and chunk.usage.completion_tokens == last_completion_tokens + ) + ) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) last_completion_tokens = chunk.usage.completion_tokens assert last_completion_tokens == 10 @pytest.mark.asyncio -async def test_guided_choice_chat(client: openai.AsyncOpenAI, - sample_guided_choice): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for 
type-safe systems programming is " - }] +async def test_guided_choice_chat(client: openai.AsyncOpenAI, sample_guided_choice): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict(guided_choice=sample_guided_choice), + ) choice1 = chat_completion.choices[0].message.content assert choice1 in sample_guided_choice messages.append({"role": "assistant", "content": choice1}) - messages.append({ - "role": "user", - "content": "I disagree, pick another one" - }) + messages.append({"role": "user", "content": "I disagree, pick another one"}) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict(guided_choice=sample_guided_choice), + ) choice2 = chat_completion.choices[0].message.content assert choice2 in sample_guided_choice assert choice1 != choice2 @pytest.mark.asyncio -async def test_guided_json_chat(client: openai.AsyncOpenAI, - sample_json_schema): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] +async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(guided_json=sample_json_schema), + ) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": message.content}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(guided_json=sample_json_schema), + ) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -569,21 +537,19 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, @pytest.mark.asyncio async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example IP address with this regex: {sample_regex}" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example IP address with this regex: {sample_regex}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, 
messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(guided_regex=sample_regex), + ) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(sample_regex, ip1) is not None @@ -594,7 +560,8 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex): model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(guided_regex=sample_regex), + ) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(sample_regex, ip2) is not None @@ -603,45 +570,41 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex): @pytest.mark.asyncio async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - extra_body=dict(guided_regex={ - 1: "Python", - 2: "C++" - })) + _ = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body=dict(guided_regex={1: "Python", 2: "C++"}), + ) @pytest.mark.asyncio -async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, - sample_guided_choice): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] +async def test_guided_choice_chat_logprobs( + client: openai.AsyncOpenAI, sample_guided_choice +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict(guided_choice=sample_guided_choice), + ) assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.content is not None @@ -654,16 +617,14 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] # non-streaming @@ -671,20 +632,17 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - 
"name": "dummy_function_name" + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }, + ], + tool_choice={"type": "function", "function": {"name": "dummy_function_name"}}, ) message = chat_completion.choices[0].message assert len(message.content) == 0 @@ -693,12 +651,9 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": json_string}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) # streaming @@ -706,21 +661,19 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - "name": "dummy_function_name" + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }, - stream=True) + ], + tool_choice={"type": "function", "function": {"name": "dummy_function_name"}}, + stream=True, + ) output = [] finish_reason_count = 0 @@ -743,11 +696,11 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_required_tool_use(client: openai.AsyncOpenAI, - is_v1_server: bool, model_name: str): +async def test_required_tool_use( + client: openai.AsyncOpenAI, is_v1_server: bool, model_name: str +): if is_v1_server: - pytest.skip( - "tool_choice='required' requires features unsupported on V1") + pytest.skip("tool_choice='required' requires features unsupported on V1") tools = [ { @@ -760,20 +713,16 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, "properties": { "city": { "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", + "description": "The city to find the weather for, e.g. 'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. 'Austria'", }, "unit": { "type": "string", - "description": - "The unit to fetch the temperature in", + "description": "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, }, @@ -791,26 +740,20 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, "properties": { "city": { "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", + "description": "The city to get the forecast for, e.g. 'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. 
'Austria'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, "unit": { "type": "string", - "description": - "The unit to fetch the temperature in", + "description": "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, }, @@ -821,19 +764,11 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, ] messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, { "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ + "content": "Can you tell me what the current weather is in Berlin and the " "forecast for the next 5 days, in fahrenheit?", }, ] @@ -867,64 +802,66 @@ async def test_required_tool_use(client: openai.AsyncOpenAI, @pytest.mark.asyncio -async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, - sample_json_schema): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] +async def test_inconsistent_tool_choice_and_tools( + client: openai.AsyncOpenAI, sample_json_schema +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tool_choice={ - "type": "function", - "function": { - "name": - "dummy_function_name" - } - }) + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tool_choice={ + "type": "function", + "function": {"name": "dummy_function_name"}, + }, + ) with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }], + ], tool_choice={ "type": "function", - "function": { - "name": "nondefined_function_name" - } - }) + "function": {"name": "nondefined_function_name"}, + }, + ) with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }], - tool_choice={}) + ], + tool_choice={}, + ) @pytest.mark.asyncio @@ -932,13 +869,17 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): for _ in range(2): resp = await 
client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') - }], - response_format={"type": "json_object"}) + messages=[ + { + "role": "user", + "content": ( + "what is 1+1? please respond with a JSON object, " + 'the format is {"result": 2}' + ), + } + ], + response_format={"type": "json_object"}, + ) content = resp.choices[0].message.content assert content is not None @@ -954,10 +895,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], ) content = resp.choices[0].message.content assert content is not None @@ -968,10 +906,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], response_format={ "type": "json_schema", "json_schema": { @@ -979,13 +914,12 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): "schema": { "type": "object", "properties": { - "result": { - "type": "integer" - }, + "result": {"type": "integer"}, }, }, - } - }) + }, + }, + ) content = resp.choices[0].message.content assert content is not None @@ -998,13 +932,16 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): async def test_extra_fields_allowed(client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?", - "extra_field": "0", - }], # type: ignore + messages=[ + { + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content is not None @@ -1014,18 +951,20 @@ async def test_extra_fields_allowed(client: openai.AsyncOpenAI): async def test_complex_message_content(client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": [{ - "type": - "text", - "text": - "what is 1+1? please provide the result without any other text." - }] - }], + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "what is 1+1? please provide the result without any other text.", + } + ], + } + ], temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content == "2" @@ -1037,24 +976,27 @@ async def test_custom_role(client: openai.AsyncOpenAI): resp1 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": "what is 1+1?", - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": "what is 1+1?", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) resp2 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": [{ - "type": "text", - "text": "what is 1+1?" 
- }] - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": [{"type": "text", "text": "what is 1+1?"}], + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content1 = resp1.choices[0].message.content content2 = resp2.choices[0].message.content @@ -1063,22 +1005,24 @@ async def test_custom_role(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_long_seed(client: openai.AsyncOpenAI): - for seed in [ - torch.iinfo(torch.long).min - 1, - torch.iinfo(torch.long).max + 1 - ]: + for seed in [torch.iinfo(torch.long).min - 1, torch.iinfo(torch.long).max + 1]: with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - }], + messages=[ + { + "role": "system", + "content": "You are a helpful assistant.", + } + ], temperature=0, - seed=seed) + seed=seed, + ) - assert ("greater_than_equal" in exc_info.value.message - or "less_than_equal" in exc_info.value.message) + assert ( + "greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message + ) @pytest.mark.asyncio @@ -1089,15 +1033,11 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): } data = { # model_name is avoided here. - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "what is 1+1?" - }], - "max_tokens": - 5 + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "what is 1+1?"}, + ], + "max_tokens": 5, } response = requests.post(url, headers=headers, json=data) @@ -1122,10 +1062,7 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): base_url=openai_api_base, ) messages = [ - { - "role": "user", - "content": "Hello, vLLM!" - }, + {"role": "user", "content": "Hello, vLLM!"}, ] response = client.chat.completions.create( model="", # empty string @@ -1135,15 +1072,11 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] request_args = { "model": MODEL_NAME, @@ -1155,8 +1088,9 @@ async def test_invocations(server: RemoteOpenAIServer, chat_completion = await client.chat.completions.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_completion.model_dump() diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py index de63f4ed218b..c9c42efd2849 100644 --- a/tests/entrypoints/openai/test_chat_echo.py +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -44,27 +44,26 @@ class TestCase(NamedTuple): "test_case", [ TestCase(model_name=MODEL_NAME, echo=True), - TestCase(model_name=MODEL_NAME, echo=False) + TestCase(model_name=MODEL_NAME, echo=False), ], ) async def test_chat_session_with_echo_and_continue_final_message( - client: openai.AsyncOpenAI, test_case: TestCase): + client: openai.AsyncOpenAI, test_case: TestCase +): saying: str = "Here is a common saying about apple. An apple a day, keeps" # test echo with continue_final_message parameter chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], extra_body={ "echo": test_case.echo, "continue_final_message": True, - "add_generation_prompt": False - }) + "add_generation_prompt": False, + }, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 diff --git a/tests/entrypoints/openai/test_chat_logit_bias_validation.py b/tests/entrypoints/openai/test_chat_logit_bias_validation.py index e9d1a855294c..936875a25e0e 100644 --- a/tests/entrypoints/openai/test_chat_logit_bias_validation.py +++ b/tests/entrypoints/openai/test_chat_logit_bias_validation.py @@ -53,10 +53,7 @@ async def test_chat_logit_bias_valid(client): completion = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing valid logit bias" - }], + messages=[{"role": "user", "content": "Testing valid logit bias"}], max_tokens=5, logit_bias={str(valid_token_id): 1.0}, ) @@ -73,10 +70,7 @@ async def test_chat_logit_bias_invalid(client): with pytest.raises(openai.BadRequestError) as excinfo: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing invalid logit bias" - }], + messages=[{"role": "user", "content": "Testing invalid logit bias"}], max_tokens=5, logit_bias={str(invalid_token_id): 1.0}, ) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 6e32887f5ed0..17e36444fe84 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -4,8 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (apply_hf_chat_template, - load_chat_template) +from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template from vllm.entrypoints.openai.protocol import 
ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer @@ -17,48 +16,54 @@ # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATION_OUTPUT = [ - ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user + ( + "facebook/opt-125m", + chatml_jinja_path, + True, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -"""), - ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user +""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user -What is the capital of"""), - ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user +What is the capital of""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + True, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -The capital of"""), +The capital of""", + ), ] TEST_MESSAGES = [ - { - 'role': 'user', - 'content': 'Hello' - }, - { - 'role': 'assistant', - 'content': 'Hi there!' - }, - { - 'role': 'user', - 'content': 'What is the capital of' - }, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "What is the capital of"}, ] -ASSISTANT_MESSAGE_TO_CONTINUE = { - 'role': 'assistant', - 'content': 'The capital of' -} +ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"} def test_load_chat_template(): @@ -68,8 +73,11 @@ def test_load_chat_template(): # Test assertions assert template_content is not None # Hard coded value for template_chatml.jinja - assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} -{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 + assert ( + template_content + == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" + ) # noqa: E501 def test_no_load_chat_template_filelike(): @@ -91,9 +99,11 @@ def test_no_load_chat_template_literallike(): @pytest.mark.parametrize( "model,template,add_generation_prompt,continue_final_message,expected_output", - MODEL_TEMPLATE_GENERATION_OUTPUT) -def test_get_gen_prompt(model, template, add_generation_prompt, - continue_final_message, expected_output): + MODEL_TEMPLATE_GENERATION_OUTPUT, +) +def test_get_gen_prompt( + model, template, add_generation_prompt, continue_final_message, expected_output +): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -116,7 +126,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, mock_request = ChatCompletionRequest( model=model, messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] - if continue_final_message else TEST_MESSAGES, + if continue_final_message + else TEST_MESSAGES, add_generation_prompt=add_generation_prompt, 
continue_final_message=continue_final_message, ) @@ -135,4 +146,5 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Test assertion assert result == expected_output, ( f"The generated prompt does not match the expected output for " - f"model {model} and template {template}") + f"model {model} and template {template}" + ) diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py index 03730b67283c..4f23eee46211 100644 --- a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -14,9 +14,14 @@ @pytest.fixture(scope="module") def server(): # noqa: F811 args = [ - "--max-model-len", "8192", "--enforce-eager", "--reasoning-parser", - "deepseek_r1", "--enable-auto-tool-choice", "--tool-call-parser", - "hermes" + "--max-model-len", + "8192", + "--enforce-eager", + "--reasoning-parser", + "deepseek_r1", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -29,50 +34,44 @@ async def client(server): yield async_client -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] - -MESSAGES = [{ - "role": "user", - "content": "Hi! How are you doing today?" -}, { - "role": "assistant", - "content": "I'm doing well! How can I help you?" -}, { - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" -}] +] + +MESSAGES = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! 
How can I help you?"}, + { + "role": "user", + "content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?", + }, +] FUNC_NAME = "get_current_weather" FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}""" @@ -105,9 +104,7 @@ def extract_reasoning_and_calls(chunks: list): # test streaming @pytest.mark.asyncio -async def test_chat_streaming_of_tool_and_reasoning( - client: openai.AsyncOpenAI): - +async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -120,8 +117,7 @@ async def test_chat_streaming_of_tool_and_reasoning( async for chunk in stream: chunks.append(chunk) - reasoning_content, arguments, function_names = extract_reasoning_and_calls( - chunks) + reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks) assert len(reasoning_content) > 0 assert len(function_names) > 0 and function_names[0] == FUNC_NAME assert len(arguments) > 0 and arguments[0] == FUNC_ARGS @@ -130,7 +126,6 @@ async def test_chat_streaming_of_tool_and_reasoning( # test full generate @pytest.mark.asyncio async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): - tool_calls = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -140,7 +135,5 @@ async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): ) assert len(tool_calls.choices[0].message.reasoning_content) > 0 - assert tool_calls.choices[0].message.tool_calls[0].function.name \ - == FUNC_NAME - assert tool_calls.choices[0].message.tool_calls[0].function.arguments \ - == FUNC_ARGS + assert tool_calls.choices[0].message.tool_calls[0].function.name == FUNC_NAME + assert tool_calls.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS diff --git a/tests/entrypoints/openai/test_chunked_prompt.py b/tests/entrypoints/openai/test_chunked_prompt.py index 3c8ed955a65a..b248ffa23e6f 100644 --- a/tests/entrypoints/openai/test_chunked_prompt.py +++ b/tests/entrypoints/openai/test_chunked_prompt.py @@ -42,7 +42,8 @@ async def client(server): @pytest.mark.asyncio async def test_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt prompt = "What is the capital of France?" * 400 @@ -64,8 +65,9 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: tokens_received += 1 assert chunk.choices[0].text @@ -79,15 +81,13 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( @pytest.mark.asyncio async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" * 400 - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?" 
* 400}, + ] stream = await client.chat.completions.create( model=MODEL_NAME, messages=messages, @@ -108,8 +108,9 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: if chunk.choices[0].delta.content == "": diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index b2472658ca81..db2689188130 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -27,21 +27,16 @@ def server(): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_single_input_classification(server: RemoteOpenAIServer, - model_name: str): +def test_single_input_classification(server: RemoteOpenAIServer, model_name: str): input_text = "This product was excellent and exceeded my expectations" classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_text - }, + json={"model": model_name, "input": input_text}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert output.model == MODEL_NAME @@ -51,8 +46,7 @@ def test_single_input_classification(server: RemoteOpenAIServer, @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_multiple_inputs_classification(server: RemoteOpenAIServer, - model_name: str): +def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str): input_texts = [ "The product arrived on time and works perfectly", "I'm very satisfied with my purchase, would buy again", @@ -64,13 +58,9 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_texts - }, + json={"model": model_name, "input": input_texts}, ) - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == len(input_texts) for i, item in enumerate(output.data): @@ -87,16 +77,11 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": long_text, - "truncate_prompt_tokens": 5 - }, + json={"model": model_name, "input": long_text, "truncate_prompt_tokens": 5}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == 1 assert output.data[0].index == 0 @@ -106,15 +91,12 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, - model_name: str): +def test_invalid_truncate_prompt_tokens_error( + server: RemoteOpenAIServer, model_name: str +): classification_response = requests.post( 
server.url_for("classify"), - json={ - "model": model_name, - "input": "test", - "truncate_prompt_tokens": 513 - }, + json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513}, ) error = classification_response.json() @@ -127,10 +109,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": "" - }, + json={"model": model_name, "input": ""}, ) error = classification_response.json() @@ -139,18 +118,13 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_batch_classification_empty_list(server: RemoteOpenAIServer, - model_name: str): +def test_batch_classification_empty_list(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": [] - }, + json={"model": model_name, "input": []}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert isinstance(output.data, list) @@ -161,15 +135,17 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, async def test_invocations(server: RemoteOpenAIServer): request_args = { "model": MODEL_NAME, - "input": "This product was excellent and exceeded my expectations" + "input": "This product was excellent and exceeded my expectations", } - classification_response = requests.post(server.url_for("classify"), - json=request_args) + classification_response = requests.post( + server.url_for("classify"), json=request_args + ) classification_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() classification_output = classification_response.json() @@ -177,7 +153,9 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_output.keys() == invocation_output.keys() for classification_data, invocation_data in zip( - classification_output["data"], invocation_output["data"]): + classification_output["data"], invocation_output["data"] + ): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( - invocation_data["probs"], rel=0.01) + invocation_data["probs"], rel=0.01 + ) diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 504fd72aa4ae..c6e6bf0be1f8 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -5,8 +5,7 @@ import pytest -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.utils import FlexibleArgumentParser @@ -15,7 +14,7 @@ LORA_MODULE = { "name": "module2", "path": "/path/to/module2", - "base_model_name": "llama" + "base_model_name": "llama", } CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja" assert CHATML_JINJA_PATH.exists() @@ -30,24 +29,26 @@ def 
serve_parser(): ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - ]) - expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + ] + ) + expected = [LoRAModulePath(name="module1", path="/path/to/module1")] assert args.lora_modules == expected def test_valid_json_format(serve_parser): # Test valid JSON format input - args = serve_parser.parse_args([ - '--lora-modules', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module2", path="/path/to/module2", base_model_name="llama") ] assert args.lora_modules == expected @@ -55,47 +56,53 @@ def test_valid_json_format(serve_parser): def test_invalid_json_format(serve_parser): # Test invalid JSON format input, missing closing brace with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', '{"name": "module3", "path": "/path/to/module3"' - ]) + serve_parser.parse_args( + ["--lora-modules", '{"name": "module3", "path": "/path/to/module3"'] + ) def test_invalid_type_error(serve_parser): # Test type error when values are not JSON or key=value with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - 'invalid_format' # This is not JSON or key=value format - ]) + serve_parser.parse_args( + [ + "--lora-modules", + "invalid_format", # This is not JSON or key=value format + ] + ) def test_invalid_json_field(serve_parser): # Test valid JSON format but missing required fields with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - '{"name": "module4"}' # Missing required 'path' field - ]) + serve_parser.parse_args( + [ + "--lora-modules", + '{"name": "module4"}', # Missing required 'path' field + ] + ) def test_empty_values(serve_parser): # Test when no LoRA modules are provided - args = serve_parser.parse_args(['--lora-modules', '']) + args = serve_parser.parse_args(["--lora-modules", ""]) assert args.lora_modules == [] def test_multiple_valid_inputs(serve_parser): # Test multiple valid inputs (both old and JSON format) - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module1', path='/path/to/module1'), - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module1", path="/path/to/module1"), + LoRAModulePath( + name="module2", path="/path/to/module2", base_model_name="llama" + ), ] assert args.lora_modules == expected @@ -111,40 +118,46 @@ def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser): def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser): """Ensure validation passes with tool choice enabled with a call parser""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--tool-call-parser", - "mistral", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--tool-call-parser", + "mistral", + ] + ) validate_parsed_serve_args(args) def 
test_enable_auto_choice_fails_with_enable_reasoning(serve_parser): """Ensure validation fails if reasoning is enabled with auto tool choice""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--reasoning-parser", + "deepseek_r1", + ] + ) with pytest.raises(TypeError): validate_parsed_serve_args(args) def test_passes_with_reasoning_parser(serve_parser): - """Ensure validation passes if reasoning is enabled + """Ensure validation passes if reasoning is enabled with a reasoning parser""" - args = serve_parser.parse_args(args=[ - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--reasoning-parser", + "deepseek_r1", + ] + ) validate_parsed_serve_args(args) def test_chat_template_validation_for_happy_paths(serve_parser): """Ensure validation passes if the chat template exists""" args = serve_parser.parse_args( - args=["--chat-template", - CHATML_JINJA_PATH.absolute().as_posix()]) + args=["--chat-template", CHATML_JINJA_PATH.absolute().as_posix()] + ) validate_parsed_serve_args(args) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index df9586ee84de..23b032f67d8b 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -12,6 +12,7 @@ import pytest_asyncio import regex as re import requests + # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -47,8 +48,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Copy tokenizer to adapter and add some unique tokens # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], special_tokens=True) assert added == 3 tokenizer.save_pretrained(tmp_model_dir) yield tmp_model_dir @@ -61,8 +61,9 @@ def zephyr_pa_files(): @pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): +def default_server_args( + zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files +): return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -93,8 +94,7 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ] -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) +@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) def server(default_server_args, request): if request.param: default_server_args.append(request.param) @@ -112,16 +112,20 @@ async def client(server): @pytest.mark.parametrize( # first test base model, then test loras, then test prompt adapters "model_name,num_virtual_tokens", - [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), - ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], + [ + (MODEL_NAME, 0), + ("zephyr-lora", 0), + ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS), + ], ) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, - num_virtual_tokens: int): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) +async def test_single_completion( + client: 
openai.AsyncOpenAI, model_name: str, num_virtual_tokens: int +): + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -132,7 +136,8 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6 + num_virtual_tokens, - total_tokens=11 + num_virtual_tokens) + total_tokens=11 + num_virtual_tokens, + ) # test using token IDs completion = await client.completions.create( @@ -240,11 +245,12 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str): - +async def test_too_many_completion_logprobs( + client: openai.AsyncOpenAI, model_name: str +): with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using token IDs await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -256,7 +262,8 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ) ... with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using token IDs stream = await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -281,13 +288,13 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), - (MODEL_NAME, 0), - (MODEL_NAME, 1), - (MODEL_NAME, None)]) -async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +@pytest.mark.parametrize( + "model_name, prompt_logprobs", + [(MODEL_NAME, -1), (MODEL_NAME, 0), (MODEL_NAME, 1), (MODEL_NAME, None)], +) +async def test_prompt_logprobs_completion( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, @@ -316,8 +323,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str): +async def test_completion_streaming(client: openai.AsyncOpenAI, model_name: str): prompt = "What is an LLM?" 
single_completion = await client.completions.create( @@ -327,11 +333,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, temperature=0.0, ) single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: @@ -360,11 +364,9 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): n = 3 max_tokens = 5 - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=max_tokens, n=n, stream=True + ) chunks: list[list[str]] = [[] for i in range(n)] finish_reason_count = 0 async for chunk in stream: @@ -384,53 +386,55 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): +async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): prompt = "What is the capital of France?" # Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - True, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: assert chunk.usage is None @@ -441,57 +445,63 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) 
assert final_chunk.choices == [] # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - True, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is not None assert chunk.usage.prompt_tokens > 0 assert chunk.usage.completion_tokens > 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if chunk.choices[0].finish_reason is not None: final_chunk = await stream.__anext__() assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options= # {"include_usage": None} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options= # {"include_usage": True} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": None} @@ -502,7 +512,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": None}) + stream_options={"continuous_usage_stats": None}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": True} @@ -513,7 +524,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": True}) + stream_options={"continuous_usage_stats": True}, + ) @pytest.mark.asyncio @@ -544,15 +556,19 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): extra_body=dict( # NOTE: this has to be true for n > 1 in vLLM, but # not necessary for official client. 
- use_beam_search=True), + use_beam_search=True + ), ) assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" + assert batch.choices[0].text != batch.choices[1].text, ( + "beam search should be different" + ) + assert batch.choices[0].text == batch.choices[2].text, ( + "two copies of the same prompt should be the same" + ) + assert batch.choices[1].text == batch.choices[3].text, ( + "two copies of the same prompt should be the same" + ) # test streaming batch = await client.completions.create( @@ -587,14 +603,18 @@ async def test_logits_bias(client: openai.AsyncOpenAI): seed=42, ) assert len(completion.choices[0].text) >= 5 - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), - add_special_tokens=False)["input_ids"] - assert all([ - response == expected - for response, expected in zip(response_tokens, expected_tokens) - ]) + response_tokens = tokenizer(completion.choices[0].text, add_special_tokens=False)[ + "input_ids" + ] + expected_tokens = tokenizer( + tokenizer.decode([token_id] * 5), add_special_tokens=False + )["input_ids"] + assert all( + [ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ] + ) # Test ban completion = await client.completions.create( @@ -603,16 +623,16 @@ async def test_logits_bias(client: openai.AsyncOpenAI): max_tokens=max_tokens, temperature=0.0, ) - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] + response_tokens = tokenizer(completion.choices[0].text, add_special_tokens=False)[ + "input_ids" + ] first_response = completion.choices[0].text completion = await client.completions.create( model=MODEL_NAME, prompt=prompt, max_tokens=max_tokens, temperature=0.0, - logit_bias={str(token): -100 - for token in response_tokens}, + logit_bias={str(token): -100 for token in response_tokens}, ) assert first_response != completion.choices[0].text @@ -641,9 +661,9 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI): @pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema): +async def test_guided_json_completion( + client: openai.AsyncOpenAI, guided_decoding_backend: str, sample_json_schema +): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -651,8 +671,11 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI, n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict( + guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend, + ), + ) assert completion.id is not None assert len(completion.choices) == 3 @@ -663,38 +686,42 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex): +async 
def test_guided_regex_completion( + client: openai.AsyncOpenAI, guided_decoding_backend: str, sample_regex +): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {sample_regex}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict( + guided_regex=sample_regex, guided_decoding_backend=guided_decoding_backend + ), + ) assert completion.id is not None assert len(completion.choices) == 3 for i in range(3): - assert re.fullmatch(sample_regex, - completion.choices[i].text) is not None + assert re.fullmatch(sample_regex, completion.choices[i].text) is not None @pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_choice_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice): +async def test_guided_choice_completion( + client: openai.AsyncOpenAI, guided_decoding_backend: str, sample_guided_choice +): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict( + guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend, + ), + ) assert completion.id is not None assert len(completion.choices) == 2 @@ -703,21 +730,23 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements): - +async def test_guided_grammar(client: openai.AsyncOpenAI, sample_sql_statements): completion = await client.completions.create( model=MODEL_NAME, - prompt=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), + prompt=( + "Generate a sql state that select col_1 from " + "table_1 where it is equals to 1" + ), temperature=1.0, max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) + extra_body=dict(guided_grammar=sample_sql_statements), + ) content = completion.choices[0].text # use Lark to parse the output, and make sure it's a valid parse tree from lark import Lark + parser = Lark(sample_sql_statements) parser.parse(content) @@ -734,52 +763,56 @@ async def test_guided_grammar(client: openai.AsyncOpenAI, [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], ) @pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): +async def test_echo_logprob_completion( + client: openai.AsyncOpenAI, model_name: str, logprobs_arg: int +): tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) # test using text and token IDs for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) - - prompt_text = tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt + completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg, + ) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, list) else prompt assert re.search(r"^" + prompt_text, completion.choices[0].text) logprobs = completion.choices[0].logprobs assert 
logprobs is not None assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) + assert len(logprobs.token_logprobs) > 5 and logprobs.token_logprobs[0] is None + assert len(logprobs.top_logprobs) > 5 and logprobs.top_logprobs[0] is None for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 @pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex): +async def test_guided_decoding_type_error( + client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, + sample_regex, +): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict( + guided_json=42, guided_decoding_backend=guided_decoding_backend + ), + ) with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) + extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema), + ) @pytest.mark.asyncio @@ -789,19 +822,21 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, (MODEL_NAME, False, False), (MODEL_NAME, False, True), (MODEL_NAME, True, False), - (MODEL_NAME, True, True) # should not raise BadRequestError error + (MODEL_NAME, True, True), # should not raise BadRequestError error ], ) -async def test_echo_stream_completion(client: openai.AsyncOpenAI, - model_name: str, stream: bool, - echo: bool): +async def test_echo_stream_completion( + client: openai.AsyncOpenAI, model_name: str, stream: bool, echo: bool +): saying: str = "Hello, my name is" - result = await client.completions.create(model=model_name, - prompt=saying, - max_tokens=10, - temperature=0.0, - echo=echo, - stream=stream) + result = await client.completions.create( + model=model_name, + prompt=saying, + max_tokens=10, + temperature=0.0, + echo=echo, + stream=stream, + ) stop_reason = "length" @@ -837,8 +872,7 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): request_args = { "model": MODEL_NAME, "prompt": "Hello, my name is", @@ -849,8 +883,9 @@ async def test_invocations(server: RemoteOpenAIServer, completion = await client.completions.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion.model_dump() diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index eca048d855b5..d6835db4c959 100644 --- 
a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -42,9 +42,13 @@ async def client(server): @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("tool_choice", ["auto", "required"]) @pytest.mark.parametrize("enable_thinking", [True, False]) -async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, - stream: bool, tool_choice: str, - enable_thinking: bool): +async def test_function_tool_use( + client: openai.AsyncOpenAI, + model_name: str, + stream: bool, + tool_choice: str, + enable_thinking: bool, +): tools = [ { "type": "function", @@ -56,26 +60,21 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "properties": { "city": { "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", + "description": "The city to find the weather for, e.g. 'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. 'Austria'", }, "unit": { "type": "string", - "description": - "The unit to fetch the temperature in", + "description": "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, "options": { "$ref": "#/$defs/WeatherOptions", - "description": - "Optional parameters for weather query", + "description": "Optional parameters for weather query", }, }, "required": ["country", "unit"], @@ -95,8 +94,7 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "include_forecast": { "type": "boolean", "default": False, - "description": - "Whether to include a 24-hour forecast", + "description": "Whether to include a 24-hour forecast", "title": "Include Forecast", }, "language": { @@ -122,26 +120,20 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "properties": { "city": { "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", + "description": "The city to get the forecast for, e.g. 'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. 'Austria'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, "unit": { "type": "string", - "description": - "The unit to fetch the temperature in", + "description": "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, }, @@ -152,19 +144,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ] messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, { "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" 
- }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ + "content": "Can you tell me what the current weather is in Berlin and the " "forecast for the next 5 days, in fahrenheit?", }, ] @@ -175,16 +159,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, model=model_name, tools=tools, tool_choice=tool_choice, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) if enable_thinking: - assert chat_completion.choices[0].message.\ - reasoning_content is not None - assert chat_completion.choices[0].message.\ - reasoning_content != "" + assert chat_completion.choices[0].message.reasoning_content is not None + assert chat_completion.choices[0].message.reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: @@ -195,11 +174,8 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, tools=tools, tool_choice=tool_choice, stream=True, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) output = [] async for chunk in output_stream: diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 00d3ffb61ee9..b2ae15cbf33b 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -10,6 +10,7 @@ import pytest import pytest_asyncio import torch + # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -37,8 +38,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Copy tokenizer to adapter and add some unique tokens # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], special_tokens=True) assert added == 3 tokenizer.save_pretrained(tmp_model_dir) yield tmp_model_dir @@ -65,8 +65,7 @@ def default_server_args( ] -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) +@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) def server_with_prompt_embeds(default_server_args, request): if request.param: default_server_args.append(request.param) @@ -86,13 +85,14 @@ def create_dummy_embeds(num_tokens: int = 5) -> str: dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) buffer = io.BytesIO() torch.save(dummy_embeds, buffer) - return base64.b64encode(buffer.getvalue()).decode('utf-8') + return base64.b64encode(buffer.getvalue()).decode("utf-8") @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_completions_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str +): # Test case: Single prompt embeds input encoded_embeds = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( @@ -100,7 +100,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - 
extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices[0].text) >= 1 assert completion.choices[0].prompt_logprobs is None @@ -111,7 +112,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 assert len(completion.choices[0].text) >= 1 assert len(completion.choices[1].text) >= 1 @@ -123,7 +125,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) single_output = single_completion.choices[0].text stream = await client_with_prompt_embeds.completions.create( @@ -132,7 +135,8 @@ async def test_completions_with_prompt_embeds( max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) chunks = [] finish_reason_count = 0 async for chunk in stream: @@ -152,12 +156,12 @@ async def test_completions_with_prompt_embeds( max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) chunks_stream_embeds: list[list[str]] = [[], []] finish_reason_count = 0 async for chunk in stream: - chunks_stream_embeds[chunk.choices[0].index].append( - chunk.choices[0].text) + chunks_stream_embeds[chunk.choices[0].index].append(chunk.choices[0].text) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 assert finish_reason_count == 2 @@ -173,7 +177,8 @@ async def test_completions_with_prompt_embeds( prompt="This is a prompt", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices) == 2 completion_text_only = await client_with_prompt_embeds.completions.create( model=model_name, @@ -186,18 +191,18 @@ async def test_completions_with_prompt_embeds( prompt="", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) # Embeddings responses should be handled first - assert completion_mixed.choices[0].text == completion_embeds_only.choices[ - 0].text - assert completion_mixed.choices[1].text == completion_text_only.choices[ - 0].text + assert completion_mixed.choices[0].text == completion_embeds_only.choices[0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[0].text @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_completions_errors_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str +): # Test error case: invalid prompt_embeds with pytest.raises(BadRequestError): await client_with_prompt_embeds.completions.create( @@ -205,15 +210,16 @@ async def test_completions_errors_with_prompt_embeds( model=model_name, max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": "invalid_base64"}) + extra_body={"prompt_embeds": "invalid_base64"}, + ) @pytest.mark.asyncio @pytest.mark.parametrize("logprobs_arg", [1, 0]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) 
async def test_completions_with_logprobs_and_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, - model_name: str): + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, model_name: str +): # Test case: Logprobs using prompt_embeds encoded_embeds = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( @@ -223,7 +229,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) logprobs = completion.choices[0].logprobs assert logprobs is not None @@ -243,7 +250,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 for choice in completion.choices: @@ -253,6 +261,5 @@ async def test_completions_with_logprobs_and_prompt_embeds( assert len(logprobs.token_logprobs) == 5 assert len(logprobs.top_logprobs) == 5 for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) == 5 diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py index 1fc87c8b42a7..237e921016b6 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/test_default_mm_loras.py @@ -16,8 +16,7 @@ # need a multimodal model for these tests. # Contains a modality specific lora alongside the base model -MULTIMODAL_MODEL_NAME = snapshot_download( - "microsoft/Phi-4-multimodal-instruct") +MULTIMODAL_MODEL_NAME = snapshot_download("microsoft/Phi-4-multimodal-instruct") AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora") ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." 
# noqa: E501 @@ -26,6 +25,7 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @@ -33,9 +33,8 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[False, True]) def multimodal_server(request, monkeypatch_module): # noqa: F811 - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') + monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") args = [ # use half precision for speed and memory savings in CI environment @@ -56,7 +55,7 @@ def multimodal_server(request, monkeypatch_module): # noqa: F811 "--gpu-memory-utilization", "0.8", "--default-mm-loras", - f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}", + f'{{"audio": "{AUDIO_LORA_PATH}"}}', ] with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server: @@ -80,25 +79,25 @@ async def test_default_mm_lora_chat_completions( multi_modal_client: openai.AsyncOpenAI, audio_assets: AudioTestAssets, ): - messages = [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "Can you transcribe this audio?", - }, { - "type": "audio_url", - "audio_url": { - "url": audio_assets[0].url - }, - }] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you transcribe this audio?", + }, + { + "type": "audio_url", + "audio_url": {"url": audio_assets[0].url}, + }, + ], + } + ] chat_completion = await multi_modal_client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=128, - temperature=0.0) + model=model_name, messages=messages, max_completion_tokens=128, temperature=0.0 + ) assert len(chat_completion.choices) > 0 diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index f03c96b12179..eb9786b3677d 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -12,8 +12,7 @@ from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from ...models.language.pooling.embed_utils import run_embedding_correctness_test from ...models.utils import check_embeddings_close from ...utils import RemoteOpenAIServer @@ -57,15 +56,13 @@ async def client(server): @pytest.fixture(scope="module") def hf_model(hf_runner): - with hf_runner(MODEL_NAME, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner(MODEL_NAME, dtype=DTYPE, is_sentence_transformer=True) as hf_model: yield hf_model @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): input_texts = [ "The chef prepared a delicious meal.", ] @@ -77,7 +74,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -97,7 +95,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert 
embeddings.id is not None assert len(embeddings.data) == 1 @@ -109,12 +108,12 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." + "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] embedding_response = await client.embeddings.create( model=model_name, @@ -122,7 +121,8 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -135,15 +135,20 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] embedding_response = await client.embeddings.create( model=model_name, input=input_tokens, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 4 @@ -155,19 +160,23 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_embedding(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_embedding( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("v1/embeddings"), @@ -196,64 +205,66 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, extra_body={"add_special_tokens": False}, ) completion_embeddings = EmbeddingResponse.model_validate( - completion_response.model_dump(mode="json")) + completion_response.model_dump(mode="json") + ) assert chat_embeddings.id is not None assert completion_embeddings.id is not None assert chat_embeddings.created <= completion_embeddings.created - assert chat_embeddings.model_dump( - exclude={"id", "created"}) == (completion_embeddings.model_dump( - exclude={"id", "created"})) + assert chat_embeddings.model_dump(exclude={"id", "created"}) == ( + completion_embeddings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_embedding(hf_model, client: 
openai.AsyncOpenAI, - model_name: str): +async def test_batch_base64_embedding( + hf_model, client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] - responses_float = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="float") + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) float_data = [d.embedding for d in responses_float.data] run_embedding_correctness_test(hf_model, input_texts, float_data) - responses_base64 = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="base64") + responses_base64 = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="base64" + ) base64_data = [] for data in responses_base64.data: base64_data.append( - np.frombuffer(base64.b64decode(data.embedding), - dtype="float32").tolist()) + np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist() + ) run_embedding_correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client - responses_default = await client.embeddings.create(input=input_texts, - model=model_name) + responses_default = await client.embeddings.create( + input=input_texts, model=model_name + ) default_data = [d.embedding for d in responses_default.data] run_embedding_correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation(client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] # test single embedding embedding_response = await client.embeddings.create( - model=model_name, - input=input_texts, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -263,15 +274,34 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 10 input_tokens = [ - 1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728, - 9901, 340, 2229, 385, 340, 315, 28741, 28804, 2 + 1, + 24428, + 289, + 18341, + 26165, + 285, + 19323, + 283, + 289, + 26789, + 3871, + 28728, + 9901, + 340, + 2229, + 385, + 340, + 315, + 28741, + 28804, + 2, ] embedding_response = await client.embeddings.create( - model=model_name, - input=input_tokens, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -283,8 +313,9 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, - model_name: str): +async def 
test_single_embedding_truncation_invalid( + client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] @@ -293,15 +324,17 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, response = await client.embeddings.create( model=model_name, input=input_texts, - extra_body={"truncate_prompt_tokens": 8193}) + extra_body={"truncate_prompt_tokens": 8193}, + ) assert "error" in response.object - assert "truncate_prompt_tokens value is greater than max_model_len. "\ - "Please, select a smaller truncation size." in response.message + assert ( + "truncate_prompt_tokens value is greater than max_model_len. " + "Please, select a smaller truncation size." in response.message + ) @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): input_texts = [ "The chef prepared a delicious meal.", ] @@ -314,35 +347,43 @@ async def test_invocations(server: RemoteOpenAIServer, completion_response = await client.embeddings.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.model_dump() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[completion_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -350,22 +391,25 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): "encoding_format": "float", } - chat_response = requests.post(server.url_for("v1/embeddings"), - json=request_args) + chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - 
invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[chat_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="chat", + name_1="invocation", + ) diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/openai/test_embedding_dimensions.py index 08b797dc57ad..05c2b5dcc471 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/openai/test_embedding_dimensions.py @@ -12,16 +12,17 @@ from vllm.entrypoints.openai.protocol import EmbeddingResponse from ...conftest import HfRunner -from ...models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from ...models.language.pooling.embed_utils import run_embedding_correctness_test from ...models.utils import EmbedModelInfo from ...utils import RemoteOpenAIServer MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - matryoshka_dimensions=[256]), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256], + ), ] input_texts = [ @@ -49,15 +50,14 @@ def server(model_info, dtype: str): dtype, "--enforce-eager", "--max-model-len", - "512" + "512", ] if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5": # Manually enable Matryoshka Embeddings - args.extend([ - "--trust_remote_code", "--hf_overrides", - '{"matryoshka_dimensions":[256]}' - ]) + args.extend( + ["--trust_remote_code", "--hf_overrides", '{"matryoshka_dimensions":[256]}'] + ) with RemoteOpenAIServer(model_info.name, args) as remote_server: yield remote_server @@ -65,14 +65,16 @@ def server(model_info, dtype: str): @pytest.fixture(scope="module") def hf_model(hf_runner, model_info, dtype: str): - with hf_runner(model_info.name, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner( + model_info.name, dtype=dtype, is_sentence_transformer=True + ) as hf_model: yield hf_model @pytest.mark.asyncio -async def test_matryoshka(model_info: EmbedModelInfo, - server: RemoteOpenAIServer, hf_model: HfRunner): +async def test_matryoshka( + model_info: EmbedModelInfo, server: RemoteOpenAIServer, hf_model: HfRunner +): client = server.get_async_client() async def make_request_and_correctness_test(dimensions): @@ -85,7 +87,8 @@ async def make_request_and_correctness_test(dimensions): encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -98,8 +101,7 @@ async def make_request_and_correctness_test(dimensions): assert len(embeddings.data[0].embedding) == dimensions vllm_outputs = [d.embedding for d in embeddings.data] - run_embedding_correctness_test(hf_model, prompts, vllm_outputs, - dimensions) + run_embedding_correctness_test(hf_model, prompts, vllm_outputs, dimensions) if model_info.is_matryoshka: valid_dimensions: list[Optional[int]] = [None] diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py index 9c2aef23e877..c68226409550 100644 --- 
a/tests/entrypoints/openai/test_encoder_decoder.py +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -31,10 +31,9 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -43,7 +42,8 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=2, total_tokens=7) + completion_tokens=5, prompt_tokens=2, total_tokens=7 + ) # test using token IDs completion = await client.completions.create( diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index bcdeaaacedea..5f54a881387b 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -9,6 +9,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio + # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -23,26 +24,18 @@ BADREQUEST_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), ( "test_bias", - { - "bias": "all" - }, + {"bias": "all"}, "Adapter bias cannot be used without bias_enabled", ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] @@ -56,29 +49,28 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @pytest.fixture(scope="module", params=[False, True]) -def server_with_lora_modules_json(request, monkeypatch_module, - zephyr_lora_files): - +def server_with_lora_modules_json(request, monkeypatch_module, zephyr_lora_files): use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') + monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") # Define the json format LoRA module configurations lora_module_1 = { "name": "zephyr-lora", "path": zephyr_lora_files, - "base_model_name": MODEL_NAME + "base_model_name": MODEL_NAME, } lora_module_2 = { "name": "zephyr-lora2", "path": zephyr_lora_files, - "base_model_name": MODEL_NAME + "base_model_name": MODEL_NAME, } args = [ @@ -110,14 +102,12 @@ def server_with_lora_modules_json(request, monkeypatch_module, @pytest_asyncio.fixture async def client(server_with_lora_modules_json): - async with server_with_lora_modules_json.get_async_client( - ) as async_client: + async with server_with_lora_modules_json.get_async_client() as async_client: yield async_client @pytest.mark.asyncio -async def test_static_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): +async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = 
models.data served_model = models[0] @@ -125,23 +115,19 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, assert served_model.id == MODEL_NAME assert served_model.root == MODEL_NAME assert served_model.parent is None - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" @pytest.mark.asyncio -async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): - - response = await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "zephyr-lora-3", - "lora_path": zephyr_lora_files - }) +async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): + response = await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files}, + ) # Ensure adapter loads before querying /models assert "success" in response @@ -156,37 +142,37 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, @pytest.mark.asyncio async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI): with pytest.raises(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) @pytest.mark.asyncio -async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, - tmp_path): +async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") with pytest.raises(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid-json", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid-json", "lora_path": str(invalid_files)}, + ) @pytest.mark.asyncio -@pytest.mark.parametrize("test_name,config_change,expected_error", - BADREQUEST_CASES) -async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files, test_name: str, - config_change: dict, - expected_error: str): +@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES) +async def test_dynamic_lora_badrequests( + client: openai.AsyncOpenAI, + tmp_path, + zephyr_lora_files, + test_name: str, + config_change: dict, + expected_error: str, +): # Create test directory test_dir = tmp_path / test_name @@ -206,29 +192,28 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, # Test loading the adapter with pytest.raises(openai.BadRequestError, match=expected_error): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": test_name, - "lora_path": str(test_dir) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": test_name, "lora_path": str(test_dir)}, + ) @pytest.mark.asyncio -async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files): - """Validate that many loras can be dynamically registered and inferenced +async def test_multiple_lora_adapters( + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): 
+ """Validate that many loras can be dynamically registered and inferenced with concurrently""" # This test file configures the server with --max-cpu-loras=2 and this test # will concurrently load 10 adapters, so it should flex the LRU cache async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -238,8 +223,7 @@ async def load_and_run_adapter(adapter_name: str): lora_tasks = [] for i in range(10): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) @@ -249,8 +233,8 @@ async def load_and_run_adapter(adapter_name: str): @pytest.mark.asyncio async def test_loading_invalid_adapters_does_not_break_others( - client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files): - + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") @@ -281,20 +265,18 @@ async def run_good_requests(client): # Run a bunch of bad adapter loads for _ in range(25): with suppress(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) for _ in range(25): with suppress(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid", "lora_path": str(invalid_files)}, + ) # Ensure all the running requests with lora adapters succeeded stop_good_requests_event.set() @@ -303,12 +285,11 @@ async def run_good_requests(client): assert not isinstance(r, Exception), f"Got exception {r}" # Ensure we can load another adapter and run it - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "valid", - "lora_path": zephyr_lora_files - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "valid", "lora_path": zephyr_lora_files}, + ) await client.completions.create( model="valid", prompt=["Hello there", "Foo bar bazz buzz"], @@ -325,12 +306,11 @@ async def test_beam_search_with_lora_adapters( """Validate that async beam search can be used with lora.""" async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -341,8 +321,7 @@ async def load_and_run_adapter(adapter_name: str): lora_tasks = [] for i in range(3): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) diff --git 
a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index d4afdf7751c8..6b91552c4565 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -13,8 +13,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.transformers_utils.tokenizer import get_tokenizer @@ -33,14 +32,14 @@ class MockHFConfig: @dataclass class MockModelConfig: """Minimal mock ModelConfig for testing.""" + model: str = MODEL_NAME tokenizer: str = MODEL_NAME trust_remote_code: bool = False tokenizer_mode: str = "auto" max_model_len: int = 100 tokenizer_revision: Optional[str] = None - multimodal_config: MultiModalConfig = field( - default_factory=MultiModalConfig) + multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig) logits_processor_pattern: Optional[str] = None diff_sampling_param: Optional[dict] = None @@ -53,17 +52,21 @@ def get_diff_sampling_param(self): class MockLoRAResolver(LoRAResolver): - - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test-lora": - return LoRARequest(lora_name="test-lora", - lora_int_id=1, - lora_local_path="/fake/path/test-lora") + return LoRARequest( + lora_name="test-lora", + lora_int_id=1, + lora_local_path="/fake/path/test-lora", + ) elif lora_name == "invalid-lora": - return LoRARequest(lora_name="invalid-lora", - lora_int_id=2, - lora_local_path="/fake/path/invalid-lora") + return LoRARequest( + lora_name="invalid-lora", + lora_int_id=2, + lora_local_path="/fake/path/invalid-lora", + ) return None @@ -92,29 +95,28 @@ def mock_add_lora_side_effect(lora_request: LoRARequest): return elif lora_request.lora_name == "invalid-lora": # Simulate failure during addition (e.g. 
invalid format) - raise ValueError(f"Simulated failure adding LoRA: " - f"{lora_request.lora_name}") + raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}") mock_engine.add_lora.side_effect = mock_add_lora_side_effect mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() mock_model_config = MockModelConfig() - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) - serving_completion = OpenAIServingCompletion(mock_engine, - mock_model_config, - models, - request_logger=None) + serving_completion = OpenAIServingCompletion( + mock_engine, mock_model_config, models, request_logger=None + ) return mock_engine, serving_completion @pytest.mark.asyncio -async def test_serving_completion_with_lora_resolver(mock_serving_setup, - monkeypatch): +async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -136,14 +138,13 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup, assert called_lora_request.lora_name == lora_model_name mock_engine.generate.assert_called_once() - called_lora_request = mock_engine.generate.call_args[1]['lora_request'] + called_lora_request = mock_engine.generate.call_args[1]["lora_request"] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name @pytest.mark.asyncio -async def test_serving_completion_resolver_not_found(mock_serving_setup, - monkeypatch): +async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -166,7 +167,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, @pytest.mark.asyncio async def test_serving_completion_resolver_add_lora_fails( - mock_serving_setup, monkeypatch): + mock_serving_setup, monkeypatch +): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 2d7b845736b8..c47065df4c34 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -54,19 +54,22 @@ def default_server_args(): ] -@pytest.fixture(scope="module", - params=[ - "", - "--enable-chunked-prefill", - "--disable-frontend-multiprocessing", - f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", - ]) +@pytest.fixture( + scope="module", + params=[ + "", + "--enable-chunked-prefill", + "--disable-frontend-multiprocessing", + f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", + ], +) def server(use_v1, default_server_args, request): if request.param: default_server_args.append(request.param) - env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0') - with RemoteOpenAIServer(MODEL_NAME, default_server_args, - env_dict=env_dict) as remote_server: + env_dict = dict(VLLM_USE_V1="1" if use_v1 else "0") + with RemoteOpenAIServer( + MODEL_NAME, default_server_args, env_dict=env_dict + ) as remote_server: yield remote_server @@ -87,30 +90,36 @@ async def client(server): # {metric_family: [(suffix, expected_value)]} EXPECTED_VALUES = { 
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": - [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], + "vllm:time_per_output_token_seconds": [ + ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1)) + ], "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": - [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - "vllm:request_generation_tokens": - [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], + "vllm:request_prompt_tokens": [ + ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS), + ], + "vllm:request_generation_tokens": [ + ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS), + ], "vllm:request_params_n": [("_count", _NUM_REQUESTS)], "vllm:request_params_max_tokens": [ ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS) + ("_count", _NUM_REQUESTS), + ], + "vllm:iteration_tokens_total": [ + ( + "_sum", + _NUM_REQUESTS + * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST), + ), + ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), ], - "vllm:iteration_tokens_total": - [("_sum", _NUM_REQUESTS * - (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)], - "vllm:prompt_tokens": [("_total", - _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], + "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], "vllm:generation_tokens": [ ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) ], @@ -119,14 +128,16 @@ async def client(server): @pytest.mark.asyncio -async def test_metrics_counts(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_counts( + server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool +): for _ in range(_NUM_REQUESTS): # sending a request triggers the metrics to be logged. 
await client.completions.create( model=MODEL_NAME, prompt=_TOKENIZED_PROMPT, - max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) + max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST, + ) response = requests.get(server.url_for("metrics")) print(response.text) @@ -134,9 +145,10 @@ async def test_metrics_counts(server: RemoteOpenAIServer, # Loop over all expected metric_families for metric_family, suffix_values_list in EXPECTED_VALUES.items(): - if ((use_v1 and metric_family not in EXPECTED_METRICS_V1) - or (not server.show_hidden_metrics - and metric_family in HIDDEN_DEPRECATED_METRICS)): + if (use_v1 and metric_family not in EXPECTED_METRICS_V1) or ( + not server.show_hidden_metrics + and metric_family in HIDDEN_DEPRECATED_METRICS + ): continue found_metric = False @@ -160,14 +172,15 @@ async def test_metrics_counts(server: RemoteOpenAIServer, assert sample.value == expected_value, ( f"{metric_name_w_suffix} expected value of " f"{expected_value} did not match found value " - f"{sample.value}") + f"{sample.value}" + ) break assert found_suffix, ( f"Did not find {metric_name_w_suffix} in prom endpoint" ) break - assert found_metric, (f"Did not find {metric_family} in prom endpoint") + assert found_metric, f"Did not find {metric_family} in prom endpoint" EXPECTED_METRICS = [ @@ -277,20 +290,19 @@ async def test_metrics_counts(server: RemoteOpenAIServer, @pytest.mark.asyncio -async def test_metrics_exist(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_exist( + server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool +): # sending a request triggers the metrics to be logged. - await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + await client.completions.create( + model=MODEL_NAME, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): - if (not server.show_hidden_metrics - and metric not in HIDDEN_DEPRECATED_METRICS): + for metric in EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS: + if not server.show_hidden_metrics and metric not in HIDDEN_DEPRECATED_METRICS: assert metric in response.text @@ -303,27 +315,30 @@ def test_metrics_exist_run_batch(use_v1: bool): port = "8001" server_url = f"http://{base_url}:{port}" - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - sys.executable, - "-m", - "vllm.entrypoints.openai.run_batch", - "-i", - input_file.name, - "-o", - output_file.name, - "--model", - "intfloat/multilingual-e5-small", - "--enable-metrics", - "--url", - base_url, - "--port", - port, - ], ) + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "vllm.entrypoints.openai.run_batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + "--enable-metrics", + "--url", + base_url, + "--port", + port, + ], + ) def is_server_up(url): try: diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 1980daa80db9..9444a8a677b0 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -4,6 +4,7 @@ import openai # use the 
official client for correctness check import pytest import pytest_asyncio + # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -61,7 +62,6 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): lora_models = models[1:] assert served_model.id == MODEL_NAME assert served_model.root == MODEL_NAME - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index f0ce50debe49..ba463be1d5cd 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -25,13 +25,10 @@ def run_and_test_dummy_opt_api_server(model, tp=1): client = server.get_client() completion = client.chat.completions.create( model=model, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Hello!" - }], + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], temperature=0, ) generated_text = completion.choices[0].message.content diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 580bf34f20c4..c715e9246792 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -70,11 +70,15 @@ def no_file_type(case: schemathesis.models.Case): -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ http://localhost:8000/tokenize """ # noqa: E501 - if (op.method.lower() == "post" and op.path == "/tokenize" - and hasattr(case, "body") and isinstance(case.body, dict) - and "messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): + if ( + op.method.lower() == "post" + and op.path == "/tokenize" + and hasattr(case, "body") + and isinstance(case.body, dict) + and "messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0 + ): for message in case.body["messages"]: if not isinstance(message, dict): continue @@ -102,9 +106,8 @@ def test_openapi_stateless(case: schemathesis.Case): timeout = { # requires a longer timeout - ("POST", "/v1/chat/completions"): - LONG_TIMEOUT_SECONDS, + ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, }.get(key, DEFAULT_TIMEOUT_SECONDS) - #No need to verify SSL certificate for localhost + # No need to verify SSL certificate for localhost case.call_and_validate(verify=False, timeout=timeout) diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py index 882fa0886ce3..0361cd182f27 100644 --- a/tests/entrypoints/openai/test_optional_middleware.py +++ b/tests/entrypoints/openai/test_optional_middleware.py @@ -37,7 +37,7 @@ def server(request: pytest.FixtureRequest): "--enforce-eager", "--max-num-seqs", "2", - *passed_params + *passed_params, ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -73,8 +73,9 @@ async def test_missing_api_token(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_passed_api_token(server: RemoteOpenAIServer): - response = requests.get(server.url_for("v1/models"), - headers={"Authorization": 
"Bearer test"}) + response = requests.get( + server.url_for("v1/models"), headers={"Authorization": "Bearer test"} + ) assert response.status_code == HTTPStatus.OK @@ -110,7 +111,8 @@ async def test_enable_request_id_header(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_custom_request_id_header(server: RemoteOpenAIServer): - response = requests.get(server.url_for("health"), - headers={"X-Request-Id": "Custom"}) + response = requests.get( + server.url_for("health"), headers={"X-Request-Id": "Custom"} + ) assert "X-Request-Id" in response.headers assert response.headers.get("X-Request-Id") == "Custom" diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/openai/test_pooling.py index 02165ee6d58e..2a10be084b92 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/openai/test_pooling.py @@ -47,11 +47,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): # test single pooling response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -67,11 +63,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): input_tokens = [1, 1, 1, 1, 1] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -89,16 +81,13 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." 
+ "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -111,15 +100,15 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.usage.total_tokens == 29 # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -134,18 +123,21 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_pooling(server: RemoteOpenAIServer, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("pooling"), @@ -181,24 +173,22 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, }, ) completions_response.raise_for_status() - completion_poolings = PoolingResponse.model_validate( - completions_response.json()) + completion_poolings = PoolingResponse.model_validate(completions_response.json()) assert chat_poolings.id is not None assert completion_poolings.id is not None assert chat_poolings.created <= completion_poolings.created - assert chat_poolings.model_dump( - exclude={"id", "created"}) == (completion_poolings.model_dump( - exclude={"id", "created"})) + assert chat_poolings.model_dump(exclude={"id", "created"}) == ( + completion_poolings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_pooling(server: RemoteOpenAIServer, - model_name: str): +async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] float_response = requests.post( @@ -211,9 +201,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) float_response.raise_for_status() responses_float = PoolingResponse.model_validate(float_response.json()) - float_data = [ - np.array(d.data).squeeze(-1).tolist() for d in responses_float.data - ] + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] base64_response = requests.post( 
server.url_for("pooling"), @@ -229,13 +217,15 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, decoded_responses_base64_data = [] for data in responses_base64.data: decoded_responses_base64_data.append( - np.frombuffer(base64.b64decode(data.data), - dtype="float32").tolist()) - - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=decoded_responses_base64_data, - name_0="float32", - name_1="base64") + np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist() + ) + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64", + ) # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -251,10 +241,12 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.array(d.data).squeeze(-1).tolist() for d in responses_default.data ] - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=default_data, - name_0="float32", - name_1="default") + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=default_data, + name_0="float32", + name_1="default", + ) @pytest.mark.asyncio @@ -269,39 +261,46 @@ async def test_invocations(server: RemoteOpenAIServer): "encoding_format": "float", } - completion_response = requests.post(server.url_for("pooling"), - json=request_args) + completion_response = requests.post(server.url_for("pooling"), json=request_args) completion_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.json() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=completion_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=completion_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -312,18 +311,22 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): chat_response = requests.post(server.url_for("pooling"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = 
chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=chat_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=chat_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="chat", + name_1="invocation", + ) diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index ff0730c77032..dcea9571cf7d 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -16,12 +16,12 @@ async def test_empty_prompt(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match="decoder prompt cannot be empty"): - await client.completions.create(model=model_name, - prompt="", - max_tokens=5, - temperature=0.0) + with pytest.raises( + openai.BadRequestError, match="decoder prompt cannot be empty" + ): + await client.completions.create( + model=model_name, prompt="", max_tokens=5, temperature=0.0 + ) @pytest.mark.asyncio @@ -31,12 +31,12 @@ async def test_out_of_vocab_token_ids(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match=re.compile('.*out of vocabulary.*').pattern): - await client.completions.create(model=model_name, - prompt=[999999], - max_tokens=5, - temperature=0.0) + with pytest.raises( + openai.BadRequestError, match=re.compile(".*out of vocabulary.*").pattern + ): + await client.completions.create( + model=model_name, prompt=[999999], max_tokens=5, temperature=0.0 + ) @pytest.mark.asyncio @@ -47,14 +47,13 @@ async def test_reject_multistep_with_guided_decoding(): client = remote_server.get_async_client() with pytest.raises( - openai.BadRequestError, - match=re.compile( - '.*Guided decoding .* multi-step decoding.*').pattern): + openai.BadRequestError, + match=re.compile(".*Guided decoding .* multi-step decoding.*").pattern, + ): await client.completions.create( model=model_name, prompt="Hello", max_tokens=5, temperature=0.0, - extra_body={"response_format": { - "type": "json_object" - }}) + extra_body={"response_format": {"type": "json_object"}}, + ) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe13691..b7d03bbc2634 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -32,15 +32,18 @@ def server(): def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." 
+ "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + }, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -56,16 +59,14 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", "Cross-encoder models are neat" + "The capital of France is Paris.", + "Cross-encoder models are neat", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - "top_n": 2 - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents, "top_n": 2}, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -78,28 +79,26 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): - query = "What is the capital of France?" * 100 documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents}, + ) assert rerank_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - rerank_response.text + assert "Please reduce the length of the input." in rerank_response.text def test_invocations(server: RemoteOpenAIServer): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." 
+ "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] request_args = { @@ -108,20 +107,22 @@ def test_invocations(server: RemoteOpenAIServer): "documents": documents, } - rerank_response = requests.post(server.url_for("rerank"), - json=request_args) + rerank_response = requests.post(server.url_for("rerank"), json=request_args) rerank_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() rerank_output = rerank_response.json() invocation_output = invocation_response.json() assert rerank_output.keys() == invocation_output.keys() - for rerank_result, invocations_result in zip(rerank_output["results"], - invocation_output["results"]): + for rerank_result, invocations_result in zip( + rerank_output["results"], invocation_output["results"] + ): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.01) + invocations_result["relevance_score"], rel=0.01 + ) diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index 099062e55c72..8f5a3104e6e0 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -10,11 +10,13 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -from .test_completion import default_server_args # noqa: F401 -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 -from .test_completion import zephyr_pa_files # noqa: F401 -from .test_completion import MODEL_NAME +from .test_completion import ( + MODEL_NAME, + default_server_args, # noqa: F401 + zephyr_lora_added_tokens_files, # noqa: F401 + zephyr_lora_files, # noqa: F401 + zephyr_pa_files, # noqa: F401 +) @pytest.fixture(scope="module") @@ -25,22 +27,19 @@ def server_fixture(request, default_server_args): # noqa: F811 with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server: yield (remote_server, True) else: - with RemoteOpenAIServer(MODEL_NAME, - default_server_args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: yield (remote_server, False) @pytest.mark.asyncio @pytest.mark.parametrize("server_fixture", [True, False], indirect=True) -async def test_completion_return_tokens_as_token_ids_completion( - server_fixture): +async def test_completion_return_tokens_as_token_ids_completion(server_fixture): server, use_server_flag = server_fixture request_args = {} if not use_server_flag: request_args["return_tokens_as_token_ids"] = True async with server.get_async_client() as client: - completion = await client.completions.create( model=MODEL_NAME, # Include Unicode characters to test for dividing a single @@ -51,7 +50,8 @@ async def test_completion_return_tokens_as_token_ids_completion( temperature=0, max_tokens=10, logprobs=1, - extra_body=request_args) + extra_body=request_args, + ) text = completion.choices[0].text token_strs = completion.choices[0].logprobs.tokens @@ -85,22 +85,22 @@ async def test_chat_return_tokens_as_token_ids_completion(server_fixture): # Include Unicode characters to test for dividing a single # character across multiple tokens: 🎉 is [28705, 31862] for the # Zephyr 
tokenizer - messages=[{ - "role": "system", - "content": "You like to respond in only emojis, like 🎉" - }, { - "role": "user", - "content": "Please write some emojis: 🐱🐶🎉" - }], + messages=[ + { + "role": "system", + "content": "You like to respond in only emojis, like 🎉", + }, + {"role": "user", "content": "Please write some emojis: 🐱🐶🎉"}, + ], temperature=0, max_tokens=8, logprobs=True, - extra_body=request_args) + extra_body=request_args, + ) text = response.choices[0].message.content tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) token_ids = [] for logprob_content in response.choices[0].logprobs.content: - token_ids.append( - int(logprob_content.token.removeprefix("token_id:"))) + token_ids.append(int(logprob_content.token.removeprefix("token_id:"))) assert tokenizer.decode(token_ids, skip_special_tokens=True) == text diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py index 7b4966848b9d..6bcb80878f07 100644 --- a/tests/entrypoints/openai/test_root_path.py +++ b/tests/entrypoints/openai/test_root_path.py @@ -51,26 +51,31 @@ class TestCase(NamedTuple): model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), ], ) -async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, - test_case: TestCase): +async def test_chat_session_root_path_with_api_key( + server: RemoteOpenAIServer, test_case: TestCase +): saying: str = "Here is a common saying about apple. 
An apple a day, keeps" ctx = contextlib.nullcontext() if test_case.expected_error is not None: @@ -79,20 +84,16 @@ async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, client = openai.AsyncOpenAI( api_key=test_case.api_key, base_url=server.url_for(*test_case.base_url), - max_retries=0) + max_retries=0, + ) chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], - extra_body={ - "continue_final_message": True, - "add_generation_prompt": False - }) + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], + extra_body={"continue_final_message": True, "add_generation_prompt": False}, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index e23f41e983b0..d31dadf90679 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -35,15 +35,24 @@ def test_empty_file(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write("") input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -53,15 +62,24 @@ def test_empty_file(): def test_completions(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "NousResearch/Meta-Llama-3-8B-Instruct", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -77,30 +95,48 @@ def test_completions_invalid_input(): """ Ensure that we fail when the input doesn't conform to the openai api. 
""" - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INVALID_INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "NousResearch/Meta-Llama-3-8B-Instruct", + ], + ) proc.communicate() proc.wait() assert proc.returncode != 0, f"{proc=}" def test_embeddings(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_EMBEDDING_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -112,24 +148,26 @@ def test_embeddings(): BatchRequestOutput.model_validate_json(line) -@pytest.mark.parametrize("input_batch", - [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) +@pytest.mark.parametrize("input_batch", [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) def test_score(input_batch): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - "vllm", - "run-batch", - "-i", - input_file.name, - "-o", - output_file.name, - "--model", - "BAAI/bge-reranker-v2-m3", - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "BAAI/bge-reranker-v2-m3", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 187542b7bafc..a5381ec1f87b 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -21,14 +21,8 @@ def v1(run_with_both_engines): MODELS = [ - { - "name": "BAAI/bge-reranker-v2-m3", - "is_cross_encoder": True - }, - { - "name": "BAAI/bge-base-en-v1.5", - "is_cross_encoder": False - }, + {"name": "BAAI/bge-reranker-v2-m3", "is_cross_encoder": True}, + {"name": "BAAI/bge-base-en-v1.5", "is_cross_encoder": False}, ] DTYPE = "half" @@ -37,9 +31,7 @@ def run_transformers(hf_model, model, text_pairs): if model["is_cross_encoder"]: return hf_model.predict(text_pairs).tolist() else: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] return [ F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0) for pair in hf_embeddings @@ -63,8 +55,9 @@ def server(model: dict[str, Any]): def runner(model: dict[str, Any], hf_runner): kwargs = { "dtype": DTYPE, - "is_cross_encoder" if model["is_cross_encoder"]\ - else "is_sentence_transformer": True + "is_cross_encoder" + 
if model["is_cross_encoder"] + else "is_sentence_transformer": True, } with hf_runner(model["name"], **kwargs) as hf_model: @@ -72,21 +65,23 @@ def runner(model: dict[str, Any], hf_runner): class TestModel: - - def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -102,23 +97,26 @@ def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_list_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = [ "What is the capital of the United States?", - "What is the capital of France?" + "What is the capital of France?", ] text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -134,17 +132,20 @@ def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_str( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -160,40 +161,41 @@ def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_score_max_model_len(self, server: RemoteOpenAIServer, - model: dict[str, Any]): - + def test_score_max_model_len( + self, server: RemoteOpenAIServer, model: dict[str, Any] + ): text_1 = "What is the capital of France?" * 20 text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." 
+ "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) assert score_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - score_response.text + assert "Please reduce the length of the input." in score_response.text # Test truncation - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - "truncate_prompt_tokens": 101 - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "truncate_prompt_tokens": 101, + }, + ) assert score_response.status_code == 400 - assert "Please, select a smaller truncation size." in \ - score_response.text + assert "Please, select a smaller truncation size." in score_response.text - def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, - Any]): + def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." @@ -203,20 +205,22 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, "text_2": text_2, } - score_response = requests.post(server.url_for("score"), - json=request_args) + score_response = requests.post(server.url_for("score"), json=request_args) score_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() score_output = score_response.json() invocation_output = invocation_response.json() assert score_output.keys() == invocation_output.keys() - for score_data, invocation_data in zip(score_output["data"], - invocation_output["data"]): + for score_data, invocation_data in zip( + score_output["data"], invocation_output["data"] + ): assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( - invocation_data["score"], rel=0.01) + invocation_data["score"], rel=0.01 + ) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 8a7892cf6d6a..ade16ad35781 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -13,8 +13,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" @@ -50,7 +49,6 @@ def get_diff_sampling_param(self): @dataclass class MockEngine: - async def get_model_config(self): return MockModelConfig() @@ -60,13 +58,15 @@ async def _async_serving_chat_init(): model_config = await engine.get_model_config() models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) - serving_completion = OpenAIServingChat(engine, - model_config, - models, - 
response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_completion = OpenAIServingChat( + engine, + model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) return serving_completion @@ -81,23 +81,24 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=MockModelConfig()) - serving_chat = OpenAIServingChat(mock_engine, - MockModelConfig(), - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig(), + ) + serving_chat = OpenAIServingChat( + mock_engine, + MockModelConfig(), + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], guided_decoding_backend="outlines", ) @@ -125,24 +126,25 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) + serving_chat = OpenAIServingChat( + mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) # Test Case 1: No max_tokens specified in request req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], guided_decoding_backend="outlines", ) @@ -180,24 +182,25 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) + serving_chat = OpenAIServingChat( + mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) # Test case 1: No max_tokens specified, defaults to context_window req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], + messages=[{"role": "user", "content": "what is 1+1?"}], guided_decoding_backend="outlines", ) @@ -225,11 +228,10 @@ async def test_serving_chat_should_set_correct_max_tokens(): @pytest.mark.asyncio async def test_serving_chat_could_load_correct_generation_config(): - mock_model_config = MockModelConfig() mock_model_config.diff_sampling_param = { "temperature": 0.5, - "repetition_penalty": 1.05 + "repetition_penalty": 1.05, } mock_engine = MagicMock(spec=MQLLMEngineClient) @@ -237,23 +239,24 @@ async def test_serving_chat_could_load_correct_generation_config(): mock_engine.errored = False # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) + serving_chat = OpenAIServingChat( + mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], guided_decoding_backend="outlines", ) @@ -291,24 +294,25 @@ async def test_serving_chat_did_set_correct_cache_salt(): mock_engine.errored = False # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) + serving_chat = OpenAIServingChat( + mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) # Test cache_salt req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) # By default cache_salt in the engine prompt is not set diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index 5f334c754a3f..3d21919489f3 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -8,19 +8,20 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import (ErrorResponse, - LoadLoRAAdapterRequest, - UnloadLoRAAdapterRequest) -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] -LORA_LOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully." LORA_UNLOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' removed successfully.") + "Success: LoRA adapter '{lora_name}' removed successfully." +) async def _async_serving_models_init() -> OpenAIServingModels: @@ -29,11 +30,13 @@ async def _async_serving_models_init() -> OpenAIServingModels: # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 - serving_models = OpenAIServingModels(engine_client=mock_engine_client, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config, - lora_modules=None, - prompt_adapters=None) + serving_models = OpenAIServingModels( + engine_client=mock_engine_client, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + lora_modules=None, + prompt_adapters=None, + ) await serving_models.init_static_loras() return serving_models @@ -43,19 +46,18 @@ async def _async_serving_models_init() -> OpenAIServingModels: async def test_serving_model_name(): serving_models = await _async_serving_models_init() assert serving_models.model_name(None) == MODEL_NAME - request = LoRARequest(lora_name="adapter", - lora_path="/path/to/adapter2", - lora_int_id=1) + request = LoRARequest( + lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1 + ) assert serving_models.model_name(request) == request.lora_name @pytest.mark.asyncio async def test_load_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter", - lora_path="/path/to/adapter2") + request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter") assert len(serving_models.lora_requests) == 1 assert "adapter" in serving_models.lora_requests assert serving_models.lora_requests["adapter"].lora_name == "adapter" @@ -74,15 +76,16 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = 
LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 1 - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" @@ -93,15 +96,15 @@ async def test_load_lora_adapter_duplicate(): @pytest.mark.asyncio async def test_unload_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert len(serving_models.lora_requests) == 1 request = UnloadLoRAAdapterRequest(lora_name="adapter1") response = await serving_models.unload_lora_adapter(request) - assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 0 diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 29a94c852bba..ff46df81d0ff 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -24,16 +24,13 @@ async def test_shutdown_on_engine_failure(): with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: async with remote_server.get_async_client() as client: - - with pytest.raises( - (openai.APIConnectionError, openai.InternalServerError)): + with pytest.raises((openai.APIConnectionError, openai.InternalServerError)): # Asking for lots of prompt logprobs will currently crash the # engine. 
This may change in the future when that bug is fixed prompt = "Hello " * 4000 await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - extra_body={"prompt_logprobs": 10}) + model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10} + ) # Now the server should shut down return_code = remote_server.proc.wait(timeout=8) diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py index 0dd6af17ef22..e07436f89d2d 100644 --- a/tests/entrypoints/openai/test_sleep.py +++ b/tests/entrypoints/openai/test_sleep.py @@ -20,14 +20,12 @@ def test_sleep_mode(): "--enable-sleep-mode", ] - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={ - "VLLM_SERVER_DEV_MODE": "1", - "CUDA_VISIBLE_DEVICES": "0" - }) as remote_server: - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"}, + ) as remote_server: + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) assert response.status_code == 200 @@ -40,12 +38,12 @@ def test_sleep_mode(): assert response.json().get("is_sleeping") is False # test wake up with tags - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["weights"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["weights"]} + ) assert response.status_code == 200 # is sleeping should be false after waking up any part of the engine @@ -53,8 +51,9 @@ def test_sleep_mode(): assert response.status_code == 200 assert response.json().get("is_sleeping") is True - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["kv_cache"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["kv_cache"]} + ) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 4bf379850365..9b24fdfa5c91 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -11,7 +11,10 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model) + TensorizerConfig, + tensorize_lora_adapter, + tensorize_vllm_model, +) from ...utils import RemoteOpenAIServer @@ -29,21 +32,20 @@ def cleanup(): _cleanup() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tmp_dir(): with tempfile.TemporaryDirectory() as path: yield path -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def model_uri(tmp_dir): yield f"{tmp_dir}/model.tensors" @pytest.fixture(scope="module") def tensorize_model_and_lora(tmp_dir, model_uri): - tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, - lora_dir=tmp_dir) + tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir) args = EngineArgs(model=MODEL_NAME, device="cuda") tensorize_lora_adapter(LORA_PATH, tensorizer_config) @@ -66,8 +68,11 @@ def server(model_uri, 
tensorize_model_and_lora): ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--served-model-name", MODEL_NAME, - "--enable-lora" + "--load-format", + "tensorizer", + "--served-model-name", + MODEL_NAME, + "--enable-lora", ] model_dir = os.path.dirname(model_uri) @@ -85,10 +90,9 @@ async def client(server): @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): _cleanup() - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -97,4 +101,5 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert len(completion.choices[0].text) >= 5 assert completion.choices[0].finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 0dbbdfbfd24a..307dc0cacfe9 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -8,8 +8,10 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 +from .test_completion import ( + zephyr_lora_added_tokens_files, # noqa: F401 + zephyr_lora_files, # noqa: F401 +) # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -40,10 +42,10 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 @pytest.fixture(scope="module") -def tokenizer_name(model_name: str, - zephyr_lora_added_tokens_files: str): # noqa: F811 - return zephyr_lora_added_tokens_files if ( - model_name == "zephyr-lora2") else model_name +def tokenizer_name(model_name: str, zephyr_lora_added_tokens_files: str): # noqa: F811 + return ( + zephyr_lora_added_tokens_files if (model_name == "zephyr-lora2") else model_name + ) @pytest_asyncio.fixture @@ -63,19 +65,20 @@ async def test_tokenize_completions( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_special in [False, True]: prompt = "vllm1 This is a test prompt." 
tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt, + }, + ) response.raise_for_status() result = response.json() @@ -96,48 +99,39 @@ async def test_tokenize_chat( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": "user", - "content": "Hi there!" - }, { - "role": "assistant", - "content": "Nice to meet you!" - }, { - "role": "user", - "content": "Can I ask a question? vllm1" - }] + conversation = [ + {"role": "user", "content": "Hi there!"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "Can I ask a question? vllm1"}, + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, continue_final_message=continue_final, conversation=conversation, - tokenize=False) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) - - response = requests.post(server.url_for("tokenize"), - json={ - "add_generation_prompt": - add_generation, - "continue_final_message": - continue_final, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) + tokenize=False, + ) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + }, + ) response.raise_for_status() result = response.json() @@ -158,41 +152,35 @@ async def test_tokenize_chat_with_tools( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": - "user", - "content": - "What's the weather like in Paris today?", - }] - - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string" - } + conversation = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, }, }, - }, - }] + } + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, @@ -201,8 +189,7 @@ async def test_tokenize_chat_with_tools( 
tools=tools, tokenize=False, ) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) response = requests.post( server.url_for("tokenize"), @@ -235,17 +222,12 @@ async def test_tokenize_with_return_token_strs( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a token_strs test prompt! vllm1" response = requests.post( server.url_for("tokenize"), - json={ - "prompt": prompt, - "model": model_name, - "return_token_strs": True - }, + json={"prompt": prompt, "model": model_name, "return_token_strs": True}, ) response.raise_for_status() @@ -270,17 +252,14 @@ async def test_detokenize( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a test prompt. vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post(server.url_for("detokenize"), - json={ - "model": model_name, - "tokens": tokens - }) + response = requests.post( + server.url_for("detokenize"), json={"model": model_name, "tokens": tokens} + ) response.raise_for_status() assert response.json() == {"prompt": prompt} @@ -329,14 +308,15 @@ async def test_tokenizer_info_schema(server: RemoteOpenAIServer): } for field, expected_type in field_types.items(): if field in result and result[field] is not None: - assert isinstance( - result[field], - expected_type), (f"{field} should be {expected_type.__name__}") + assert isinstance(result[field], expected_type), ( + f"{field} should be {expected_type.__name__}" + ) @pytest.mark.asyncio async def test_tokenizer_info_added_tokens_structure( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test added_tokens_decoder structure if present.""" response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() @@ -347,25 +327,23 @@ async def test_tokenizer_info_added_tokens_structure( assert isinstance(token_id, str), "Token IDs should be strings" assert isinstance(token_info, dict), "Token info should be a dict" assert "content" in token_info, "Token info should have content" - assert "special" in token_info, ( - "Token info should have special flag") - assert isinstance(token_info["special"], - bool), ("Special flag should be boolean") + assert "special" in token_info, "Token info should have special flag" + assert isinstance(token_info["special"], bool), ( + "Special flag should be boolean" + ) @pytest.mark.asyncio async def test_tokenizer_info_consistency_with_tokenize( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test that tokenizer info is consistent with tokenization endpoint.""" info_response = requests.get(server.url_for("tokenizer_info")) info_response.raise_for_status() info = info_response.json() tokenize_response = requests.post( server.url_for("tokenize"), - json={ - "model": MODEL_NAME, - "prompt": "Hello world!" 
- }, + json={"model": MODEL_NAME, "prompt": "Hello world!"}, ) tokenize_response.raise_for_status() tokenize_result = tokenize_response.json() @@ -373,7 +351,8 @@ async def test_tokenizer_info_consistency_with_tokenize( tokenize_max_len = tokenize_result.get("max_model_len") if info_max_len and tokenize_max_len: assert info_max_len >= tokenize_max_len, ( - "Info max length should be >= tokenize max length") + "Info max length should be >= tokenize max length" + ) @pytest.mark.asyncio @@ -384,6 +363,5 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): result = response.json() chat_template = result.get("chat_template") if chat_template: - assert isinstance(chat_template, - str), ("Chat template should be a string") - assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file + assert isinstance(chat_template, str), "Chat template should be a string" + assert chat_template.strip(), "Chat template should not be empty" diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index a8e2eb40b157..b5fc3b0f471c 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -18,29 +18,33 @@ from ...utils import RemoteOpenAIServer MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] @pytest.fixture def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_local_path() + path = AudioAsset("mary_had_lamb").get_local_path() with open(str(path), "rb") as f: yield f @pytest.fixture def winning_call(): - path = AudioAsset('winning_call').get_local_path() + path = AudioAsset("winning_call").get_local_path() with open(str(path), "rb") as f: yield f @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]) + "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"] +) async def test_basic_audio(mary_had_lamb, model_name): server_args = ["--enforce-eager"] @@ -55,8 +59,9 @@ async def test_basic_audio(mary_had_lamb, model_name): file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert "Mary had a little lamb," in out @@ -69,10 +74,9 @@ async def test_bad_requests(mary_had_lamb): # invalid language with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=model_name, - file=mary_had_lamb, - language="hh", - temperature=0.0) + await client.audio.transcriptions.create( + model=model_name, file=mary_had_lamb, language="hh", temperature=0.0 + ) @pytest.mark.asyncio @@ -90,7 +94,7 @@ async def test_long_audio_request(mary_had_lamb, model_name): repeated_audio = np.tile(audio, 10) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() @@ -99,8 +103,9 @@ async def test_long_audio_request(mary_had_lamb, model_name): file=buffer, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = 
json.loads(transcription)["text"] counts = out.count("Mary had a little lamb") assert counts == 10, counts @@ -112,10 +117,9 @@ async def test_non_asr_model(winning_call): server_args = ["--enforce-eager"] with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - res = await client.audio.transcriptions.create(model=model_name, - file=winning_call, - language="en", - temperature=0.0) + res = await client.audio.transcriptions.create( + model=model_name, file=winning_call, language="en", temperature=0.0 + ) assert res.code == 400 and not res.text assert res.message == "The model does not support Transcriptions API" @@ -129,10 +133,8 @@ async def test_completion_endpoints(): client = remote_server.get_async_client() res = await client.chat.completions.create( model=model_name, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }]) + messages=[{"role": "system", "content": "You are a helpful assistant."}], + ) assert res.code == 400 assert res.message == "The model does not support Chat Completions API" @@ -153,13 +155,14 @@ async def test_streaming_response(winning_call): file=winning_call, response_format="json", language="en", - temperature=0.0) + temperature=0.0, + ) # Unfortunately this only works when the openai client is patched # to use streaming mode, not exposed in the transcription api. original_post = AsyncAPIClient.post async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True + kwargs["stream"] = True return await original_post(*args, **kwargs) with patch.object(AsyncAPIClient, "post", new=post_with_stream): @@ -170,11 +173,12 @@ async def post_with_stream(*args, **kwargs): language="en", temperature=0.0, extra_body=dict(stream=True), - timeout=30) + timeout=30, + ) # Reconstruct from chunks and validate async for chunk in res: # just a chunk - text = chunk.choices[0]['delta']['content'] + text = chunk.choices[0]["delta"]["content"] transcription += text assert transcription == res_no_stream.text @@ -188,7 +192,7 @@ async def test_stream_options(winning_call): original_post = AsyncAPIClient.post async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True + kwargs["stream"] = True return await original_post(*args, **kwargs) with patch.object(AsyncAPIClient, "post", new=post_with_stream): @@ -198,10 +202,13 @@ async def post_with_stream(*args, **kwargs): file=winning_call, language="en", temperature=0.0, - extra_body=dict(stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True), - timeout=30) + extra_body=dict( + stream=True, + stream_include_usage=True, + stream_continuous_usage_stats=True, + ), + timeout=30, + ) final = False continuous = True async for chunk in res: @@ -209,7 +216,7 @@ async def post_with_stream(*args, **kwargs): # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') + continuous = continuous and hasattr(chunk, "usage") assert final and continuous @@ -217,7 +224,7 @@ async def post_with_stream(*args, **kwargs): async def test_sampling_params(mary_had_lamb): """ Compare sampling with params and greedy sampling to assert results - are different when extreme sampling parameters values are picked. + are different when extreme sampling parameters values are picked. 
""" model_name = "openai/whisper-small" server_args = ["--enforce-eager"] @@ -228,20 +235,24 @@ async def test_sampling_params(mary_had_lamb): file=mary_had_lamb, language="en", temperature=0.8, - extra_body=dict(seed=42, - repetition_penalty=1.9, - top_k=12, - top_p=0.4, - min_p=0.5, - frequency_penalty=1.8, - presence_penalty=2.0)) + extra_body=dict( + seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0, + ), + ) greedy_transcription = await client.audio.transcriptions.create( model=model_name, file=mary_had_lamb, language="en", temperature=0.0, - extra_body=dict(seed=42)) + extra_body=dict(seed=42), + ) assert greedy_transcription.text != transcription.text @@ -252,7 +263,7 @@ async def test_audio_prompt(mary_had_lamb): server_args = ["--enforce-eager"] prompt = "This is a speech, recorded in a phonograph." with RemoteOpenAIServer(model_name, server_args) as remote_server: - #Prompts should not omit the part of original prompt while transcribing. + # Prompts should not omit the part of original prompt while transcribing. prefix = "The first words I spoke in the original phonograph" client = remote_server.get_async_client() transcription = await client.audio.transcriptions.create( @@ -260,8 +271,9 @@ async def test_audio_prompt(mary_had_lamb): file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert prefix in out transcription_wprompt = await client.audio.transcriptions.create( model=model_name, @@ -269,6 +281,7 @@ async def test_audio_prompt(mary_had_lamb): language="en", response_format="text", prompt=prompt, - temperature=0.0) - out_prompt = json.loads(transcription_wprompt)['text'] + temperature=0.0, + ) + out_prompt = json.loads(transcription_wprompt)["text"] assert prefix in out_prompt diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index 79e769e3a1aa..8f15d9d43e92 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io + # imports for guided decoding tests import json from unittest.mock import patch @@ -20,7 +21,7 @@ @pytest.fixture def foscolo(): # Test translation it->en - path = AudioAsset('azacinto_foscolo').get_local_path() + path = AudioAsset("azacinto_foscolo").get_local_path() with open(str(path), "rb") as f: yield f @@ -38,8 +39,9 @@ async def test_basic_audio(foscolo): response_format="text", # TODO remove once language detection is implemented extra_body=dict(language="it"), - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert "greek sea" in out @@ -57,8 +59,9 @@ async def test_audio_prompt(foscolo): prompt=prompt, extra_body=dict(language="it"), response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert "Nor will I ever touch the sacred" not in out assert prompt not in out @@ -70,9 +73,9 @@ async def test_non_asr_model(foscolo): server_args = ["--enforce-eager"] with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - res = await 
client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0) + res = await client.audio.translations.create( + model=model_name, file=foscolo, temperature=0.0 + ) assert res.code == 400 and not res.text assert res.message == "The model does not support Translations API" @@ -89,27 +92,28 @@ async def test_streaming_response(foscolo): file=foscolo, response_format="json", extra_body=dict(language="it"), - temperature=0.0) + temperature=0.0, + ) # Unfortunately this only works when the openai client is patched # to use streaming mode, not exposed in the translation api. original_post = AsyncAPIClient.post async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True + kwargs["stream"] = True return await original_post(*args, **kwargs) with patch.object(AsyncAPIClient, "post", new=post_with_stream): client = remote_server.get_async_client() - res = await client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0, - extra_body=dict( - stream=True, - language="it")) + res = await client.audio.translations.create( + model=model_name, + file=foscolo, + temperature=0.0, + extra_body=dict(stream=True, language="it"), + ) # Reconstruct from chunks and validate async for chunk in res: # just a chunk - text = chunk.choices[0]['delta']['content'] + text = chunk.choices[0]["delta"]["content"] translation += text assert translation == res_no_stream.text @@ -123,7 +127,7 @@ async def test_stream_options(foscolo): original_post = AsyncAPIClient.post async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True + kwargs["stream"] = True return await original_post(*args, **kwargs) with patch.object(AsyncAPIClient, "post", new=post_with_stream): @@ -132,10 +136,13 @@ async def post_with_stream(*args, **kwargs): model=model_name, file=foscolo, temperature=0.0, - extra_body=dict(language="it", - stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True)) + extra_body=dict( + language="it", + stream=True, + stream_include_usage=True, + stream_continuous_usage_stats=True, + ), + ) final = False continuous = True async for chunk in res: @@ -143,7 +150,7 @@ async def post_with_stream(*args, **kwargs): # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') + continuous = continuous and hasattr(chunk, "usage") assert final and continuous @@ -157,7 +164,7 @@ async def test_long_audio_request(foscolo): repeated_audio = np.tile(audio, 2) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() @@ -166,6 +173,7 @@ async def test_long_audio_request(foscolo): file=buffer, extra_body=dict(language="it"), response_format="text", - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index b33a26af65b3..774315041d07 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -54,12 +54,10 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": 
truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == truncation_size @@ -70,15 +68,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } with pytest.raises(openai.BadRequestError) as err: - err = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + err = await client.post(path="embeddings", cast_to=object, body={**kwargs}) - assert str(err) == f"""openai.BadRequestError: + assert ( + str(err) + == f"""openai.BadRequestError: Error code: 400 - {{'object': 'error', 'message': 'truncate_prompt_tokens value ({truncation_size}) @@ -86,6 +84,7 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): Please, select a smaller truncation size.', 'type': 'BadRequestError', 'param': None, 'code': 400}}""" + ) @pytest.mark.asyncio @@ -94,11 +93,9 @@ async def test_max_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == max_model_len diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index b68e08556ee9..825dbc7d2e48 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -58,24 +58,18 @@ def base64_encoded_video() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] +async def test_single_chat_session_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -84,13 +78,15 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -112,54 +108,44 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": video_url - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_error_on_invalid_video_url_type( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": video_url}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # video_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] +async def test_single_chat_session_video_beamsearch( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -168,36 +154,38 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -206,13 +194,15 @@ async def test_single_chat_session_video_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -236,58 +226,54 @@ async def test_single_chat_session_video_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_chat_streaming_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_chat_streaming_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -327,27 +313,23 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "video_urls", - [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]) -async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str, - video_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "video_url", - "video_url": { - "url": video_url - } - } for video_url in video_urls), - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + "video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))] +) +async def test_multi_video_input( + client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "video_url", "video_url": {"url": video_url}} + for video_url in video_urls + ), + {"type": "text", "text": "What's in this video?"}, + ], + } + ] if len(video_urls) > MAXIMUM_VIDEOS: with pytest.raises(openai.BadRequestError): # test multi-video input diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index b6f1d64803e5..e9984a38d068 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -79,19 +79,22 @@ def base64_encoded_image() -> dict[str, str]: def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|>\n" - messages = [{ - "role": "user", - "content": f"{placeholder}{content}", - }] + messages = [ + { + "role": "user", + "content": f"{placeholder}{content}", + } + ] images = [Image.open(requests.get(image_url, stream=True).raw)] prompt = processor.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True) + messages, tokenize=False, add_generation_prompt=True + ) inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] @@ -100,25 +103,19 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_single_chat_session_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): +async def test_single_chat_session_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" 
- messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -128,17 +125,18 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -160,55 +158,45 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): +async def test_error_on_invalid_image_url_type( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": image_url - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": image_url}, + {"type": "text", "text": content_text}, + ], + } + ] # image_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] +async def test_single_chat_session_image_beamsearch( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -217,37 +205,39 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_single_chat_session_image_base64encoded( - client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: dict[str, str]): - + client: openai.AsyncOpenAI, + model_name: str, + image_url: str, + base64_encoded_image: dict[str, str], +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + }, + }, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -257,17 +247,18 @@ async def test_single_chat_session_image_base64encoded( max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -291,36 +282,37 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS)))) async def test_single_chat_session_image_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, image_idx: int, - base64_encoded_image: dict[str, str]): + client: openai.AsyncOpenAI, + model_name: str, + image_idx: int, + base64_encoded_image: dict[str, str], +): # NOTE: This test also validates that we pass MM data through beam search image_url = TEST_IMAGE_URLS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + }, + }, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, temperature=0.0, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 for actual, expected_str in zip(chat_completion.choices, expected_res): assert actual.message.content == expected_str @@ -329,24 +321,18 @@ async def test_single_chat_session_image_base64encoded_beamsearch( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_chat_streaming_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] +async def test_chat_streaming_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -386,27 +372,23 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) -async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] + "image_urls", [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))] +) +async def test_multi_image_input( + client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index fe982e286ae4..86c9ae6c6b93 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -58,9 +58,9 @@ def base64_encoded_image() -> dict[str, str]: def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" @@ -72,39 +72,28 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, - image_url: str): +async def test_image_embedding( + server: RemoteOpenAIServer, model_name: str, image_url: str +): content_text = "Represent the given image." - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] response = requests.post( server.url_for("v1/embeddings"), - json={ - "model": model_name, - "messages": messages, - "encoding_format": "float" - }, + json={"model": model_name, "messages": messages, "encoding_format": "float"}, ) response.raise_for_status() embeddings = EmbeddingResponse.model_validate(response.json()) - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert embeddings.id is not None assert len(embeddings.data) == 1 diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bd8e06513e13..bdd5344652c4 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -8,15 +8,18 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager def make_tool_call(name, arguments): - return ToolCall(type="function", - function=FunctionCall(name=name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=name, arguments=json.dumps(arguments)), + ) # TODO: add reason prefix and suffix. 
@@ -29,70 +32,68 @@ def make_tool_call(name, arguments): ("How can I help you today?", [], "How can I help you today?"), # Single tool call, no content ( - "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]", #noqa: E501 + '[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]', # noqa: E501 [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) ], - None), + None, + ), # Multiple tool calls ( - "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]", #noqa: E501 + '[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]', # noqa: E501 [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }), make_tool_call( - "register_user", { + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ), + make_tool_call( + "register_user", + { "name": "John Doe", "age": 37, - "address": { - "city": "San Francisco", - "state": "CA" - }, + "address": {"city": "San Francisco", "state": "CA"}, "role": None, "passed_test": True, - "aliases": ["John", "Johnny"] - }) + "aliases": ["John", "Johnny"], + }, + ), ], - None), + None, + ), # Content before tool call ( - "I will call the tool now. [{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]", #noqa: E501 + 'I will call the tool now. [{"name": "get_weather", "arguments": {"city": "Boston"}}]', # noqa: E501 [make_tool_call("get_weather", {"city": "Boston"})], - "I will call the tool now. "), + "I will call the tool now. ", + ), # Content after tool call (should be stripped) ( - "[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]\nThank you!", #noqa: E501 + '[{"name": "get_weather", "arguments": {"city": "Seattle"}}]\nThank you!', # noqa: E501 [make_tool_call("get_weather", {"city": "Seattle"})], - None), + None, + ), ( - "[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]", + '[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]', [ make_tool_call( - "complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) ], None, ), - ]) -def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, - expected_content): + ], +) +def test_hunyuan_a13b_tool_parser_extract( + model_output, expected_tool_calls, expected_content +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=False) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=False + ) # align the random id. 
for idx in range(len(tool_calls)): @@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, # Streaming test: simulate incremental output -@pytest.mark.parametrize("model_deltas,expected_tool_calls", [ - ([ - "[{\"name\": \"get_weather\", ", - "\"arguments\": {\"city\": \"San Francisco\", ", - "\"metric\": \"celsius\"}}]", "" - ], [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) - ]), - ([ - "[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - ([ - "", "[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "", "\n" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - pytest.param([ - "[{\"name\": \"complex_tool\",", " \"arguments\": ", - " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}", - "]" - ], [ - make_tool_call("complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) +@pytest.mark.parametrize( + "model_deltas,expected_tool_calls", + [ + ( + [ + '[{"name": "get_weather", ', + '"arguments": {"city": "San Francisco", ', + '"metric": "celsius"}}]', + "", + ], + [ + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) + ], + ), + ( + [ + '[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + ( + [ + "", + '[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "", + "\n", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + pytest.param( + [ + '[{"name": "complex_tool",', + ' "arguments": ', + ' {"level1": {"level2": ', + '{"level3": {"value": 123}}}}}', + "]", + ], + [ + make_tool_call( + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) + ], + marks=pytest.mark.xfail( + reason="stream parsing not support nested json yet." + ), + ), ], - marks=pytest.mark.xfail( - reason="stream parsing not support nested json yet.")), -]) +) def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) reconstructor = run_tool_extraction_streaming( - tool_parser, model_deltas, assert_one_tool_per_delta=False) + tool_parser, model_deltas, assert_one_tool_per_delta=False + ) # align the random id. 
for idx in range(len(reconstructor.tool_calls)): diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8c86b4889e15..94277980f229 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -6,7 +6,9 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -16,12 +18,14 @@ name="get_weather", arguments='{"city": "LA", "metric": "C"}', ) -MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " - "age=9, " - "address={'city': 'LA', 'state': 'CA'}, " - "role=None, " - "passed_test=True, " - "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_OUTPUT = ( + "[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "Doe", ' @@ -34,7 +38,7 @@ PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -47,25 +51,28 @@ arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', ) PYTHON_TAG_FUNCTION_OUTPUT = ( - "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>" +) @pytest.mark.parametrize("streaming", [True, False]) def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output = "How can I help you today?" 
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 @@ -75,98 +82,139 @@ def test_no_tool_call(streaming: bool): test_str += "[get_weather(city='LA', metric='C')," test_str += "register_user(name='Doe', age=9)]" TEST_CASES = [ - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming" + ), + pytest.param( + True, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming", + ), pytest.param( True, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - 
id="parallel_calls_streaming"), + id="parallel_calls_streaming", + ), pytest.param( False, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_nonstreaming", + ), + pytest.param( + True, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming", + ), + pytest.param( + False, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming", + ), + pytest.param( + True, + test_str, + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_streaming", + ), + pytest.param( + False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - id="parallel_calls_nonstreaming"), - pytest.param(True, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_streaming"), - pytest.param(False, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_nonstreaming"), - pytest.param(True, - test_str, [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_streaming"), - pytest.param(False, - "<|python_start|>[get_weather(city='LA', metric='C'), " + - "register_user(name='Doe', age=9)]", [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_nonstreaming"), + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert len(tool_calls) == len(expected_tool_calls) for actual, expected in zip(tool_calls, expected_tool_calls): @@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output_deltas = [ "<|python_start|>[get_weather(city='LA', metric='C'), " "get_weather(), " @@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps(): 
def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index d83137472598..ccd6abbac4c9 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -6,7 +6,9 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -22,7 +24,8 @@ "address={'city': 'San Francisco', 'state': 'CA'}, " "role=None, " "passed_test=True, " - "aliases=['John', 'Johnny'])") + "aliases=['John', 'Johnny'])" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "John Doe", ' @@ -35,7 +38,7 @@ PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -48,7 +51,8 @@ arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')") + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', @@ -59,80 +63,118 @@ def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output = "How can I help you today?" 
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 TEST_CASES = [ - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_nonstreaming"), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming", + ), + pytest.param( + True, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], 
+ id="escaped_string_nonstreaming", + ), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content is None assert len(tool_calls) == len(expected_tool_calls) @@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output_deltas = [ "[get_weather(city='San", " Francisco', metric='celsius'), " @@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index e1b41f45f554..cfa4d3584e70 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -4,15 +4,17 @@ from collections.abc import Iterable from typing import Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + 
ToolCall, +) from vllm.entrypoints.openai.tool_parsers import ToolParser class StreamingToolReconstructor: - def __init__(self, assert_one_tool_per_delta: bool = True): self.tool_calls: list[ToolCall] = [] self.other_content: str = "" @@ -23,49 +25,60 @@ def append_delta(self, delta: DeltaMessage): self.other_content += delta.content else: assert delta.tool_calls, ( - "Streaming results should have either content or tool calls " - "(or both)") + "Streaming results should have either content or tool calls (or both)" + ) if self._assert_one_tool_per_delta: # Note: This isn't strictly required by the API and may not be # possible to adhere to depending on the token space and number of # tokens per streamed response from the model, but it is required # by tool_use tests, so we enforce it here by default also. assert len(delta.tool_calls) < 2, ( - "Streaming should include only one tool call per update.") + "Streaming should include only one tool call per update." + ) for call_delta in delta.tool_calls: assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " - f"{call_delta.type}") - current_tool_call = self.tool_calls[ - call_delta.index] if call_delta.index < len( - self.tool_calls) else None + f"{call_delta.type}" + ) + current_tool_call = ( + self.tool_calls[call_delta.index] + if call_delta.index < len(self.tool_calls) + else None + ) if current_tool_call: - assert (not call_delta.function.name), ( + assert not call_delta.function.name, ( "Streaming tool calls should emit the full function name " - f"exactly once. Got {call_delta.function.name}") - assert (not call_delta.id), ( + f"exactly once. Got {call_delta.function.name}" + ) + assert not call_delta.id, ( "Streaming tool calls must emit function id only once. Got " - f"{call_delta.id}") - assert (call_delta.index == len(self.tool_calls) - 1), ( + f"{call_delta.id}" + ) + assert call_delta.index == len(self.tool_calls) - 1, ( f"Incorrect index for tool delta. Got {call_delta.index}, " - f"expected {len(self.tool_calls) - 1}") - current_tool_call.function.arguments += ( - call_delta.function.arguments) + f"expected {len(self.tool_calls) - 1}" + ) + current_tool_call.function.arguments += call_delta.function.arguments else: assert call_delta.id is not None, ( - "Streaming tool calls must have an id on first appearance") + "Streaming tool calls must have an id on first appearance" + ) assert call_delta.function.name is not None, ( - "Streaming tool calls must have a function name on first " - "appearance") + "Streaming tool calls must have a function name on first appearance" + ) assert call_delta.index == len(self.tool_calls), ( f"Incorrect index for tool delta. 
Got {call_delta.index}, " - f"expected {len(self.tool_calls)}") + f"expected {len(self.tool_calls)}" + ) self.tool_calls.append( - ToolCall(id=call_delta.id, - function=FunctionCall( - name=call_delta.function.name, - arguments=call_delta.function.arguments - or ""))) + ToolCall( + id=call_delta.id, + function=FunctionCall( + name=call_delta.function.name, + arguments=call_delta.function.arguments or "", + ), + ) + ) def run_tool_extraction( @@ -80,11 +93,11 @@ def run_tool_extraction( tool_parser, model_output, request, - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta, + ) return reconstructor.other_content or None, reconstructor.tool_calls else: - extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, - request) + extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request) assert extracted.tools_called == bool(extracted.tool_calls) return extracted.content, extracted.tool_calls @@ -92,7 +105,7 @@ def run_tool_extraction( def run_tool_extraction_nonstreaming( tool_parser: ToolParser, model_output: str, - request: Union[ChatCompletionRequest, None] = None + request: Union[ChatCompletionRequest, None] = None, ) -> ExtractedToolCallInformation: request = request or ChatCompletionRequest(messages=[], model="test-model") return tool_parser.extract_tool_calls(model_output, request) @@ -106,7 +119,8 @@ def run_tool_extraction_streaming( ) -> StreamingToolReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingToolReconstructor( - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta + ) previous_text = "" previous_tokens: list[int] = [] for delta in model_deltas: @@ -118,8 +132,14 @@ def run_tool_extraction_streaming( current_text = previous_text + delta current_tokens = previous_tokens + token_delta delta_message = tool_parser.extract_tool_calls_streaming( - previous_text, current_text, delta, previous_tokens, - current_tokens, token_delta, request) + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + request, + ) if delta_message is not None: reconstructor.append_delta(delta_message) previous_text = current_text diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index e4af60a78265..921bfb563f02 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -10,8 +10,7 @@ import pytest -from vllm.v1.utils import (APIServerProcessManager, - wait_for_completion_or_failure) +from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure # Global variables to control worker behavior WORKER_RUNTIME_SECONDS = 0.5 @@ -30,26 +29,22 @@ def api_server_args(): """Fixture to provide arguments for APIServerProcessManager.""" sock = socket.socket() return { - "target_server_fn": - mock_run_api_server_worker, - "listen_address": - "localhost:8000", - "sock": - sock, - "args": - "test_args", # Simple string to avoid pickling issues - "num_servers": - 3, + "target_server_fn": mock_run_api_server_worker, + "listen_address": "localhost:8000", + "sock": sock, + "args": "test_args", # Simple string to avoid pickling issues + "num_servers": 3, "input_addresses": [ - "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", - "tcp://127.0.0.1:5003" + "tcp://127.0.0.1:5001", + "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003", 
], "output_addresses": [ - "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", - "tcp://127.0.0.1:6003" + "tcp://127.0.0.1:6001", + "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003", ], - "stats_update_address": - "tcp://127.0.0.1:7000", + "stats_update_address": "tcp://127.0.0.1:7000", } @@ -95,8 +90,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): assert not proc.is_alive() -@patch("vllm.entrypoints.cli.serve.run_api_server_worker", - mock_run_api_server_worker) +@patch("vllm.entrypoints.cli.serve.run_api_server_worker", mock_run_api_server_worker) def test_wait_for_completion_or_failure(api_server_args): """Test that wait_for_completion_or_failure works with failures.""" global WORKER_RUNTIME_SECONDS @@ -118,8 +112,7 @@ def run_with_exception_capture(): result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Let all processes run for a short time @@ -174,8 +167,7 @@ def test_normal_completion(api_server_args): # Verify all processes have terminated for i, proc in enumerate(manager.processes): - assert not proc.is_alive( - ), f"Process {i} still alive after terminate()" + assert not proc.is_alive(), f"Process {i} still alive after terminate()" # Now call wait_for_completion_or_failure # since all processes have already @@ -198,13 +190,13 @@ def test_external_process_monitoring(api_server_args): # Create and start the external process # (simulates local_engine_manager or coordinator) spawn_context = multiprocessing.get_context("spawn") - external_proc = spawn_context.Process(target=mock_run_api_server_worker, - name="MockExternalProcess") + external_proc = spawn_context.Process( + target=mock_run_api_server_worker, name="MockExternalProcess" + ) external_proc.start() # Create the class to simulate a coordinator class MockCoordinator: - def __init__(self, proc): self.proc = proc @@ -228,14 +220,14 @@ def close(self): def run_with_exception_capture(): try: - wait_for_completion_or_failure(api_server_manager=manager, - coordinator=mock_coordinator) + wait_for_completion_or_failure( + api_server_manager=manager, coordinator=mock_coordinator + ) except Exception as e: result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Terminate the external process to trigger a failure @@ -246,21 +238,23 @@ def run_with_exception_capture(): wait_thread.join(timeout=1.0) # The wait thread should have completed - assert not wait_thread.is_alive( - ), "wait_for_completion_or_failure thread still running" + assert not wait_thread.is_alive(), ( + "wait_for_completion_or_failure thread still running" + ) # Verify that an exception was raised with appropriate error message assert result["exception"] is not None, "No exception was raised" error_message = str(result["exception"]) - assert "died with exit code" in error_message, \ + assert "died with exit code" in error_message, ( f"Unexpected error message: {error_message}" - assert "MockExternalProcess" in error_message, \ + ) + assert "MockExternalProcess" in error_message, ( f"Error doesn't mention external process: {error_message}" + ) # Verify that all API server processes were terminated as a result for i, proc in 
enumerate(manager.processes): - assert not proc.is_alive( - ), f"API server process {i} was not terminated" + assert not proc.is_alive(), f"API server process {i} was not terminated" finally: # Clean up diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index e321ca70001d..e189d9585f6c 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -11,15 +11,21 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, - parse_chat_messages, - parse_chat_messages_futures, - resolve_chat_template_content_format, - resolve_hf_chat_template) +from vllm.entrypoints.chat_utils import ( + _try_extract_ast, + load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + resolve_chat_template_content_format, + resolve_hf_chat_template, +) from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, - encode_video_base64) +from vllm.multimodal.utils import ( + encode_audio_base64, + encode_image_base64, + encode_video_base64, +) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from ..models.registry import HF_EXAMPLE_MODELS @@ -41,31 +47,35 @@ @pytest.fixture(scope="function") def phi3v_model_config(): - return ModelConfig(PHI3V_MODEL_ID, - task="generate", - tokenizer=PHI3V_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="auto", - seed=0, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + task="generate", + tokenizer=PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="function") def phi3v_model_config_mm_interleaved(): - return ModelConfig(PHI3V_MODEL_ID, - task="generate", - tokenizer=PHI3V_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="auto", - seed=0, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + task="generate", + tokenizer=PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -80,18 +90,20 @@ def phi3v_tokenizer(): @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): - return ModelConfig(QWEN25OMNI_MODEL_ID, - task="generate", - tokenizer=QWEN25OMNI_MODEL_ID, - tokenizer_mode="auto", - dtype="auto", - seed=0, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - "audio": 1, - "video": 1, - }) + return ModelConfig( + QWEN25OMNI_MODEL_ID, + task="generate", + tokenizer=QWEN25OMNI_MODEL_ID, + tokenizer_mode="auto", + dtype="auto", + seed=0, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }, + ) @pytest.fixture(scope="module") @@ -106,16 +118,18 @@ def qwen25omni_tokenizer(): @pytest.fixture(scope="module") def mllama_model_config(): - return ModelConfig(MLLAMA_MODEL_ID, - task="generate", - tokenizer=MLLAMA_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="auto", - seed=0, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MLLAMA_MODEL_ID, + task="generate", + tokenizer=MLLAMA_MODEL_ID, + 
tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -130,16 +144,18 @@ def mllama_tokenizer(): @pytest.fixture(scope="function") def mistral_model_config(): - return ModelConfig(MISTRAL_MODEL_ID, - task="generate", - tokenizer=MISTRAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="auto", - seed=0, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MISTRAL_MODEL_ID, + task="generate", + tokenizer=MISTRAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -154,21 +170,21 @@ def mistral_tokenizer(): @pytest.fixture(scope="module") def image_url(): - image = ImageAsset('cherry_blossom') + image = ImageAsset("cherry_blossom") base64 = encode_image_base64(image.pil_image) return f"data:image/jpeg;base64,{base64}" @pytest.fixture(scope="module") def video_url(): - video = VideoAsset('baby_reading', 1) + video = VideoAsset("baby_reading", 1) base64 = encode_video_base64(video.np_ndarrays) return f"data:video/jpeg;base64,{base64}" @pytest.fixture(scope="module") def audio_url(): - audio = AudioAsset('mary_had_lamb') + audio = AudioAsset("mary_had_lamb") base64 = encode_audio_base64(*audio.audio_and_sample_rate) return f"data:audio/ogg;base64,{base64}" @@ -209,28 +225,23 @@ def test_parse_chat_messages_single_image( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1) @@ -240,58 +251,33 @@ def test_parse_chat_messages_empty_system( ): # Test string format conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], + [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ], mistral_model_config, mistral_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": "Who are you?" - }] + assert conversation == [ + {"role": "system", "content": ""}, + {"role": "user", "content": "Who are you?"}, + ] # Test openai format conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], + [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ], mistral_model_config, mistral_tokenizer, content_format="openai", ) - assert conversation == [{ - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "Who are you?" 
- }] - }] + assert conversation == [ + {"role": "system", "content": [{"type": "text", "text": ""}]}, + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ] @pytest.mark.asyncio @@ -301,28 +287,23 @@ async def test_parse_chat_messages_single_image_async( image_url, ): conversation, mm_future = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(await mm_future, 1) @@ -332,33 +313,27 @@ def test_parse_chat_messages_multiple_images( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?"} + ] _assert_mm_data_is_image_input(mm_data, 2) @@ -369,33 +344,27 @@ async def test_parse_chat_messages_multiple_images_async( image_url, ): conversation, mm_future = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?"} + ] _assert_mm_data_is_image_input(await mm_future, 2) @@ -405,36 +374,29 @@ def test_parse_chat_messages_placeholder_already_in_prompt( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" 
- }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to <|image_2|>?", + }, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] + assert conversation == [ + { + "role": "user", + "content": "What's in <|image_1|> and how does it compare to <|image_2|>?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) @@ -444,42 +406,31 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 - } - ] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501 + }, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?" - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " + "other one?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) @@ -489,52 +440,32 @@ def test_parse_chat_messages_multiple_images_across_messages( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about this one?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What about this one?"}, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) assert conversation == [ - { - "role": "user", - "content": "<|image_1|>\nWhat's in this image?" - }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "<|image_2|>\nWhat about this one?" - }, + {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "<|image_2|>\nWhat about this one?"}, ] _assert_mm_data_is_image_input(mm_data, 2) @@ -544,46 +475,23 @@ def test_parse_chat_messages_context_text_format( phi3v_tokenizer, ): conversation, mm_data = parse_chat_messages( - [{ - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" 
- }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What about this one?" - }], + [ + { + "role": "user", + "content": [{"type": "text", "text": "What's in this text?"}], + }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What about this one?"}, + ], phi3v_model_config, phi3v_tokenizer, content_format="openai", ) assert conversation == [ - { - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] - }, - { - "role": "assistant", - "content": [{ - "type": "text", - "text": "Some stuff." - }] - }, - { - "role": "user", - "content": [{ - "type": "text", - "text": "What about this one?" - }] - }, + {"role": "user", "content": [{"type": "text", "text": "What's in this text?"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Some stuff."}]}, + {"role": "user", "content": [{"type": "text", "text": "What about this one?"}]}, ] @@ -594,36 +502,23 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ): with warnings.catch_warnings(): warnings.filterwarnings( - "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + "ignore", message="coroutine 'async_get_and_parse_image' was never awaited" + ) with pytest.raises( - ValueError, - match="At most 2 image\\(s\\) may be provided in one request\\." + ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -637,46 +532,30 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ): with warnings.catch_warnings(): warnings.filterwarnings( - "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + "ignore", message="coroutine 'async_get_and_parse_image' was never awaited" + ) with pytest.raises( - ValueError, - match="At most 2 image\\(s\\) may be provided in one request\\." + ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" 
- }] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What about these two?"}, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -689,28 +568,24 @@ def test_parse_chat_messages_multiple_images_uncommon_input( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - "What's in these images?", { - "image_url": image_url - }, { - "image_url": image_url - } - ] - }], + [ + { + "role": "user", + "content": [ + "What's in these images?", + {"image_url": image_url}, + {"image_url": image_url}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?"} + ] _assert_mm_data_is_image_input(mm_data, 2) @@ -720,42 +595,30 @@ def test_parse_chat_messages_multiple_images_interleave( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "I need you to compare this image"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) @@ -766,42 +629,30 @@ async def test_parse_chat_messages_multiple_images_interleave_async( image_url, ): conversation, mm_data = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" 
- }] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "I need you to compare this image"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(await mm_data, 2) @@ -811,135 +662,84 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( image_url, ): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Be accurate." - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Be accurate."}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|image_1|>\nBe accurate." - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What's on this image?\n<|image_2|>" - }] + assert conversation == [ + {"role": "user", "content": "What's on this image?\n<|image_1|>\nBe accurate."}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What's on this image?\n<|image_2|>"}, + ] _assert_mm_data_is_image_input(mm_data, 2) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( - qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, - image_url, video_url, audio_url): + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "And what's in the video?" 
- }, { - "type": "video_url", - "video_url": { - "url": video_url - } - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Now listen to this audio"}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "And what's in the video?"}, + {"type": "video_url", "video_url": {"url": video_url}}, + ], + }, + ], qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" - }] + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) @@ -950,35 +750,25 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( image_url, ): with pytest.raises( - ValueError, - match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items."): + ValueError, + match=r"Found more '<|image_1|>' placeholders in input prompt " + "than actual multimodal data items.", + ): parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" 
- }, - ] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + }, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", @@ -993,31 +783,29 @@ def test_mllama_single_image( ): """Ensures that a single image is parsed correctly mllama.""" conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "The content of this image is:"}, + {"image_url": image_url}, + ], + } + ], mllama_model_config, mllama_tokenizer, content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 1) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - 'type': 'image' - }] - }] + assert conversation == [ + { + "role": "user", + "content": [ + {"type": "text", "text": "The content of this image is:"}, + {"type": "image"}, + ], + } + ] def test_mllama_interleaved_images( @@ -1027,46 +815,33 @@ def test_mllama_interleaved_images( ): """Ensures that multiple image are parsed as interleaved dicts.""" conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "The content of the first image is:"}, + {"image_url": image_url}, + {"type": "text", "text": "The content of the second image is:"}, + {"image_url": image_url}, + ], + } + ], mllama_model_config, mllama_tokenizer, content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 2) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of the first image is:' - }, { - 'type': 'image' - }, { - 'type': 'text', - 'text': 'The content of the second image is:' - }, { - 'type': 'image' - }] - }] + assert conversation == [ + { + "role": "user", + "content": [ + {"type": "text", "text": "The content of the first image is:"}, + {"type": "image"}, + {"type": "text", "text": "The content of the second image is:"}, + {"type": "image"}, + ], + } + ] @pytest.mark.parametrize("model", [MLLAMA_MODEL_ID]) @@ -1076,39 +851,33 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): def get_conversation(is_hf: bool): img_part = {"type": "image_url", "image_url": {"url": image_url}} if is_hf: - img_part = {'type': 'image'} - return [{ - 'role': - 'user', - 'content': [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'What animal is in the first image?' 
- }, - ] - }] + img_part = {"type": "image"} + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": "The content of the first image is:"}, + img_part, + {"type": "text", "text": "The content of the second image is:"}, + img_part, + {"type": "text", "text": "What animal is in the first image?"}, + ], + } + ] # Build a config for the model - model_config = ModelConfig(model, - task="generate", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="auto", - seed=0, - limit_mm_per_prompt={ - "image": 2, - }) + model_config = ModelConfig( + model, + task="generate", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -1154,7 +923,8 @@ def get_conversation(is_hf: bool): [ QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ]) + ], +) @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" @@ -1179,14 +949,20 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) tokenizer = tokenizer_group.tokenizer - tools = [{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] if use_tools else None + tools = ( + [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + if use_tools + else None + ) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py index 33ad2cfd3a33..b56fbd9fee7e 100644 --- a/tests/entrypoints/test_ssl_cert_refresher.py +++ b/tests/entrypoints/test_ssl_cert_refresher.py @@ -11,7 +11,6 @@ class MockSSLContext(SSLContext): - def __init__(self): self.load_cert_chain_count = 0 self.load_ca_count = 0 @@ -34,7 +33,7 @@ def load_verify_locations( def create_file() -> str: - with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f: + with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as f: return f.name diff --git a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py index 1b95bf59f67c..4a9b8aa39d70 100644 --- a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py +++ b/tests/fastsafetensors_loader/test_fastsafetensors_loader.py @@ -17,7 +17,6 @@ def test_model_loader_download_files(vllm_runner): - with vllm_runner(test_model, - load_format=LoadFormat.FASTSAFETENSORS) as llm: + with vllm_runner(test_model, load_format=LoadFormat.FASTSAFETENSORS) as llm: deserialized_outputs = llm.generate(prompts, sampling_params) assert deserialized_outputs diff --git a/tests/fastsafetensors_loader/test_weight_utils.py b/tests/fastsafetensors_loader/test_weight_utils.py index 78d23acfec7c..cc899b77b5e9 100644 --- a/tests/fastsafetensors_loader/test_weight_utils.py +++ b/tests/fastsafetensors_loader/test_weight_utils.py @@ -8,24 +8,25 @@ import torch from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, fastsafetensors_weights_iterator, - safetensors_weights_iterator) + 
download_weights_from_hf, + fastsafetensors_weights_iterator, + safetensors_weights_iterator, +) def test_fastsafetensors_model_loader(): with tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 fastsafetensors_tensors = {} hf_safetensors_tensors = {} - for name, tensor in fastsafetensors_weights_iterator( - safetensors, True): + for name, tensor in fastsafetensors_weights_iterator(safetensors, True): fastsafetensors_tensors[name] = tensor for name, tensor in safetensors_weights_iterator(safetensors, True): @@ -34,13 +35,10 @@ def test_fastsafetensors_model_loader(): assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors) for name, fastsafetensors_tensor in fastsafetensors_tensors.items(): - fastsafetensors_tensor = fastsafetensors_tensor.to('cpu') - assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[ - name].dtype - assert fastsafetensors_tensor.shape == hf_safetensors_tensors[ - name].shape - assert torch.all( - fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) + fastsafetensors_tensor = fastsafetensors_tensor.to("cpu") + assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[name].dtype + assert fastsafetensors_tensor.shape == hf_safetensors_tensors[name].shape + assert torch.all(fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) if __name__ == "__main__": diff --git a/tests/kernels/allclose_default.py b/tests/kernels/allclose_default.py index 9d65159bf64f..6561e9556fa7 100644 --- a/tests/kernels/allclose_default.py +++ b/tests/kernels/allclose_default.py @@ -6,11 +6,7 @@ # Reference default values of atol and rtol are from # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} -default_rtol = { - torch.float16: 1e-3, - torch.bfloat16: 1.6e-2, - torch.float: 1.3e-6 -} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} def get_default_atol(output) -> float: diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index 88a2fb62b254..b080a71bd54e 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,8 +3,7 @@ import pytest -from vllm.utils import (create_kv_caches_with_random, - create_kv_caches_with_random_flash) +from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash @pytest.fixture() diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 2e0b4efebfdb..e3823df74f16 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -6,9 +6,9 @@ import pytest import torch - from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops from vllm.attention.layer import Attention, MultiHeadAttention from vllm.platforms import current_platform @@ -30,9 +30,11 @@ PARTITION_SIZE = 512 PARTITION_SIZE_ROCM = 256 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not 
current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = ( + [torch.half, torch.bfloat16, torch.float] + if not current_platform.is_rocm() + else [torch.half, torch.bfloat16] +) NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing @@ -45,9 +47,7 @@ USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( @@ -113,8 +113,7 @@ def ref_single_query_cached_kv_attention( # Create the ALiBi bias used in the paged attention kernel. position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) @@ -122,8 +121,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"] +) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -146,13 +145,18 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if ((kv_cache_dtype == "fp8" and head_size % 16) - or (version == "rocm" and head_size not in (64, 128))): + if (kv_cache_dtype == "fp8" and head_size % 16) or ( + version == "rocm" and head_size not in (64, 128) + ): pytest.skip() - if (version == "rocm" and current_platform.is_navi() - and (kv_cache_dtype == "fp8" or head_size != 128 - or block_size != 16 or use_alibi)): + if ( + version == "rocm" + and current_platform.is_navi() + and ( + kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi + ) + ): pytest.skip() global PARTITION_SIZE @@ -180,18 +184,24 @@ def test_paged_attention( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) block_tables = torch.tensor(block_tables_lst, dtype=torch.int) # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -217,18 +227,37 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v1, - (output, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v1, + ( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) elif version in ("v2", "rocm"): if current_platform.is_rocm() and version == "rocm": PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -261,13 +290,34 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: ops.paged_attention_rocm( @@ -291,13 +341,30 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._rocm_C.paged_attention, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, None, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._rocm_C.paged_attention, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: raise AssertionError(f"Unknown version: {version}") @@ -306,18 +373,17 @@ def test_paged_attention( if kv_cache_dtype == "fp8": # Convert cache data back to dtype. 
x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + dequantized_key_cache = torch.empty( + size=key_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) + dequantized_value_cache = torch.empty( + size=value_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache @@ -370,8 +436,9 @@ def ref_multi_query_kv_attention( if alibi_bias: attn_mask = alibi_bias[i] else: - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1 + ) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype) @@ -393,8 +460,9 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -416,13 +484,11 @@ def test_multi_query_kv_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) + qkv = torch.empty( + num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype + ) qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) num_queries_per_kv = num_query_heads // num_kv_heads if num_queries_per_kv > 1: @@ -432,8 +498,7 @@ def test_multi_query_kv_attention( alibi_bias = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, - seq_lens) + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output = torch.empty_like(query) start = 0 # Dynamic sequence length not supported with custom attn_bias. @@ -445,7 +510,8 @@ def test_multi_query_kv_attention( value[None, start:end], attn_bias=attn_bias[i], p=0.0, - scale=scale) + scale=scale, + ) output[start:end].copy_(out.view_as(query[start:end])) start += seq_len # xformers.AttentionBias to Tensor for use in reference impl. @@ -488,8 +554,9 @@ def test_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." 
+) @torch.inference_mode() def test_multi_query_kv_attention_with_alibi( num_seqs: int, diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 93bf20da4adb..bff77c7868a1 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -15,8 +15,7 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() @@ -37,7 +36,7 @@ def clear_cache(): "cuda": [16, 64], # CUDA supports both standard and extended block sizes "hip": [16, 1], # HIP requires special handling for block_size=1 # "cpu": [16] # CPU uses fixed block size from test cases - "cpu": [] # FIXME(woosuk): Temporarily disable CPU tests + "cpu": [], # FIXME(woosuk): Temporarily disable CPU tests } @@ -45,12 +44,13 @@ def generate_params(): params = [] for use_mla in [True, False]: for device in ["cuda", "hip", "cpu"]: - backends = DEVICE_MLA_BACKENDS[ - device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device] + backends = ( + DEVICE_MLA_BACKENDS[device] + if use_mla + else DEVICE_REGULAR_ATTN_BACKENDS[device] + ) for name in backends: - block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ - 16 - ] + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16] for block_size in block_sizes: params.append( pytest.param( @@ -58,14 +58,13 @@ def generate_params(): name, use_mla, block_size, - id= - f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" - )) + id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}", + ) + ) return params -@pytest.mark.parametrize("device, name, use_mla, block_size", - generate_params()) +@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) @pytest.mark.parametrize("use_v1", [True, False]) def test_env( device: str, @@ -85,57 +84,61 @@ def test_env( if not use_v1: pytest.skip("CPU backend only supports V1") - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - block_size, False) + with patch("vllm.attention.selector.current_platform", CpuPlatform()): + backend = get_attn_backend( + 16, torch.float16, torch.float16, block_size, False + ) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "hip": - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): + with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) + valid_combination = (name == "TRITON_MLA" and block_size != 1) or ( + name == "ROCM_AITER_MLA" and block_size == 1 + ) if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) expected = f"{name}_VLLM_V1" if use_v1 else name assert backend.get_name() == expected else: with pytest.raises(ValueError) as exc_info: - get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - assert f"The selected backend, {name}" in str( - exc_info.value) + get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + 
False, + use_mla=use_mla, + ) + assert f"The selected backend, {name}" in str(exc_info.value) else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" assert backend.get_name() == expected elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): + with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: if name == "FLASHMLA" and block_size == 64: from vllm.attention.backends.flashmla import ( - is_flashmla_supported) + is_flashmla_supported, + ) # only on cuda platforms with specific capability. is_supported, _ = is_flashmla_supported() @@ -144,53 +147,63 @@ def test_env( # if platform is not supported then skip this case. pytest.skip() else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) expected = f"{name}_VLLM_V1" if use_v1 else name assert backend.get_name() == expected else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = ("TRITON_MLA_VLLM_V1" - if use_v1 else "TRITON_MLA") + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) + expected = "TRITON_MLA_VLLM_V1" if use_v1 else "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) expected = "FLASHINFER_VLLM_V1" if use_v1 else name assert backend.get_name() == expected else: - backend = get_attn_backend(32, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 32, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name assert backend.get_name() == expected if use_v1: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) + backend = get_attn_backend( + 16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla, + ) assert backend.get_name() == "FLEX_ATTENTION", ( "Should fallback to FlexAttention if head size is " - "not supported by FlashAttention") + "not supported by FlashAttention" + ) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -208,19 +221,14 @@ def test_fp32_fallback( if not use_v1: pytest.skip("CPU backend only supports V1") - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) + with patch("vllm.attention.selector.current_platform", CpuPlatform()): + backend = get_attn_backend(16, torch.float32, torch.float32, 16, False) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) - assert (backend.get_name() == "FLEX_ATTENTION" - if use_v1 else "XFORMERS") + with 
patch("vllm.attention.selector.current_platform", CudaPlatform()): + backend = get_attn_backend(16, torch.float32, torch.float32, 16, False) + assert backend.get_name() == "FLEX_ATTENTION" if use_v1 else "XFORMERS" def test_flash_attn(monkeypatch: pytest.MonkeyPatch): @@ -232,9 +240,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, - "get_device_capability", - lambda _=None: (7, 5)) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) backend = get_attn_backend(16, torch.float16, None, 16, False) assert backend.get_name() != STR_FLASH_ATTN_VAL @@ -255,17 +261,17 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): # flash-attn is not installed import sys - original_module = sys.modules.get('vllm_flash_attn') - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) + + original_module = sys.modules.get("vllm_flash_attn") + monkeypatch.setitem(sys.modules, "vllm_flash_attn", None) backend = get_attn_backend(16, torch.float16, None, 16, False) assert backend.get_name() != STR_FLASH_ATTN_VAL # Restore the original module if it existed if original_module is not None: - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', - original_module) + monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module) else: - monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) + monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False) # Unsupported head size backend = get_attn_backend(17, torch.float16, None, 16, False) @@ -278,9 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("use_v1", [True, False]) def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): - - with monkeypatch.context() as m, patch( - "vllm.attention.selector.current_platform", CudaPlatform()): + with ( + monkeypatch.context() as m, + patch("vllm.attention.selector.current_platform", CudaPlatform()), + ): m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) diff --git a/tests/kernels/attention/test_blocksparse_attention.py b/tests/kernels/attention/test_blocksparse_attention.py index 9aee818c9956..e0b291f59e82 100644 --- a/tests/kernels/attention/test_blocksparse_attention.py +++ b/tests/kernels/attention/test_blocksparse_attention.py @@ -6,11 +6,12 @@ import pytest import torch - from tests.kernels.allclose_default import get_default_atol, get_default_rtol + from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( - LocalStridedBlockSparseAttn) + LocalStridedBlockSparseAttn, +) from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes @@ -34,7 +35,7 @@ USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] BLOCKSPARSE_LOCAL_BLOCKS = [16] BLOCKSPARSE_VERT_STRIDES = [8] @@ -111,8 +112,7 @@ def ref_single_query_cached_kv_attention( # Create the ALiBi bias used in the paged attention kernel. 
position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) if blocksparse_vert_stride >= 1: bsize = blocksparse_block_size @@ -120,19 +120,18 @@ def ref_single_query_cached_kv_attention( vert = blocksparse_vert_stride locals = blocksparse_local_blocks qb = (seq_len - 1) // bsize - attn_mask = q.new_zeros( - (num_query_heads, 1, seq_len)).float() - torch.inf + attn_mask = q.new_zeros((num_query_heads, 1, seq_len)).float() - torch.inf for h in range(num_query_heads): if hsliding >= 0: # slide with q heads bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 else: # slide with kv heads - bs_offset = (tp_rank * num_kv_heads + - h // num_queries_per_kv) * (-hsliding) + 1 + bs_offset = (tp_rank * num_kv_heads + h // num_queries_per_kv) * ( + -hsliding + ) + 1 for kb in range(qb + 1): kj = kb * bsize - if (qb - kb) < locals or \ - (kb + bs_offset) % vert == 0: - attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 + if (qb - kb) < locals or (kb + bs_offset) % vert == 0: + attn_mask[h, 0, kj : min(kj + bsize, seq_len)] = 0 if alibi_bias is not None: attn_mask += alibi_bias else: @@ -156,8 +155,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) @pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) @pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) -@pytest.mark.parametrize("blocksparse_head_sliding_step", - BLOCKSPARSE_HEADS_SLIDINGS) +@pytest.mark.parametrize("blocksparse_head_sliding_step", BLOCKSPARSE_HEADS_SLIDINGS) def test_paged_attention( kv_cache_factory, version: str, @@ -198,17 +196,23 @@ def test_paged_attention( block_tables = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) block_tables = torch.tensor(block_tables, dtype=torch.int) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -240,7 +244,7 @@ def test_paged_attention( blocksparse_head_sliding_step=blocksparse_head_sliding_step, ) elif version == "v2": - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -283,18 +287,17 @@ def test_paged_attention( if kv_cache_dtype == "fp8": # Convert cache data back to dtype. 
x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + dequantized_key_cache = torch.empty( + size=key_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) + dequantized_value_cache = torch.empty( + size=value_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache @@ -346,8 +349,7 @@ def ref_multi_query_kv_attention( seq_len = end_idx - start_idx # Create attention mask. - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype) @@ -401,13 +403,11 @@ def test_varlen_blocksparse_attention_prefill( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) + qkv = torch.empty( + num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype + ) qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) bs_attn_op = LocalStridedBlockSparseAttn( num_query_heads, @@ -417,13 +417,10 @@ def test_varlen_blocksparse_attention_prefill( block_size=blocksparse_block_size, device=device, dtype=dtype, - homo_head=blocksparse_homo_heads) + homo_head=blocksparse_homo_heads, + ) - output = bs_attn_op(query, - key, - value, - cu_seq_lens.to(device), - sm_scale=scale) + output = bs_attn_op(query, key, value, cu_seq_lens.to(device), sm_scale=scale) if num_queries_per_kv > 1: # Handle MQA and GQA diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 789507615580..fad998b60120 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -5,12 +5,12 @@ import pytest import torch - from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck + from vllm import _custom_ops as ops from vllm.platforms import current_platform -COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] +COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing @@ -32,9 +32,7 @@ NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -83,24 +81,33 @@ def test_copy_blocks( block_mapping.append((src, dst2)) # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(num_blocks, block_size, - num_layers, num_heads, - head_size, kv_cache_dtype, - dtype, seed, device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) - - opcheck(torch.ops._C_cache_ops.copy_blocks, - (key_caches, value_caches, block_mapping_tensor), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(head_size == HEAD_SIZES[0])) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) + + opcheck( + torch.ops._C_cache_ops.copy_blocks, + (key_caches, value_caches, block_mapping_tensor), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + cond=(head_size == HEAD_SIZES[0]), + ) ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. @@ -113,8 +120,7 @@ def test_copy_blocks( # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): torch.testing.assert_close(key_cache, cloned_key_cache) - for value_cache, cloned_value_cache in zip(value_caches, - cloned_value_caches): + for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): torch.testing.assert_close(value_cache, cloned_value_cache) @@ -153,10 +159,17 @@ def test_reshape_and_cache( _, key, value = qkv.unbind(dim=1) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -174,12 +187,30 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. 
- opcheck(torch.ops._C_cache_ops.reshape_and_cache, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + opcheck( + torch.ops._C_cache_ops.reshape_and_cache, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) @@ -200,14 +231,12 @@ def test_reshape_and_cache( cloned_value_cache[block_idx, :, :, block_offset] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache, cloned_key_cache) torch.testing.assert_close(value_cache, cloned_value_cache) @@ -247,15 +276,8 @@ def test_reshape_and_cache_flash( # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping_lst = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device) _, key, value = qkv.unbind(dim=1) # Create the KV caches. @@ -286,40 +308,57 @@ def permute_and_compact(x): # Clone the KV caches. if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache_compact, - v_scale.item(), kv_cache_dtype) + cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype + ) + cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype + ) else: cloned_key_cache = key_cache_compact.clone() cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. 
- opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, v_scale) + opcheck( + torch.ops._C_cache_ops.reshape_and_cache_flash, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_key_cache, - key_cache_compact, - k_scale.item(), - kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_value_cache, - value_cache_compact, - v_scale.item(), - kv_dtype=kv_cache_dtype) + result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype + ) + result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_value_cache, + value_cache_compact, + v_scale.item(), + kv_dtype=kv_cache_dtype, + ) # Run the reference implementation. block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") @@ -337,14 +376,12 @@ def permute_and_compact(x): cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache_compact, cloned_key_cache) torch.testing.assert_close(value_cache_compact, cloned_value_cache) @@ -381,8 +418,8 @@ def test_swap_blocks( current_platform.seed_everything(seed) - src_device = device if direction[0] == "cuda" else 'cpu' - dst_device = device if direction[1] == "cuda" else 'cpu' + src_device = device if direction[0] == "cuda" else "cpu" + dst_device = device if direction[1] == "cuda" else "cpu" src_blocks = random.sample(range(num_blocks), num_mappings) # For the same device, mapping must not overlap @@ -393,42 +430,62 @@ def test_swap_blocks( dst_blocks = random.sample(range(num_blocks), num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device="cpu" + ).view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, src_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + src_device, + ) # Create the KV caches on the second device. 
dist_key_caches, dist_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, dst_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + dst_device, + ) src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - do_opcheck = (head_size == HEAD_SIZES[0]) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), - cond=do_opcheck) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), - cond=do_opcheck) - - ops.swap_blocks(src_key_caches[0], dist_key_caches[0], - block_mapping_tensor) - ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping_tensor) + do_opcheck = head_size == HEAD_SIZES[0] + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) + + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor) for src, dst in block_mapping: - torch.testing.assert_close(src_key_caches_clone[src].cpu(), - dist_key_caches[0][dst].cpu()) - torch.testing.assert_close(src_value_caches_clone[src].cpu(), - dist_value_caches[0][dst].cpu()) + torch.testing.assert_close( + src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu() + ) + torch.testing.assert_close( + src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu() + ) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -474,11 +531,9 @@ def _create_mla_cache( device: str, ) -> torch.Tensor: cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype - return torch.zeros(num_blocks, - block_size, - entry_size, - dtype=cache_dtype, - device=device) + return torch.zeros( + num_blocks, block_size, entry_size, dtype=cache_dtype, device=device + ) def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): @@ -518,20 +573,16 @@ def test_concat_and_cache_mla( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -543,10 +594,7 @@ def test_concat_and_cache_mla( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = ref_temp @@ -556,24 
+604,18 @@ def test_concat_and_cache_mla( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) if kv_cache_dtype == "fp8": result_temp = torch.empty_like(kv_cache, dtype=torch.float16) - ops.convert_fp8(result_temp, - kv_cache.contiguous(), - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8( + result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype + ) expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16) - ops.convert_fp8(expected_temp, - ref_kv_cache, - scale.item(), - kv_dtype=kv_cache_dtype) - torch.testing.assert_close(result_temp, - expected_temp, - atol=0.001, - rtol=0.1) + ops.convert_fp8( + expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype + ) + torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1) else: torch.testing.assert_close(kv_cache, ref_kv_cache) @@ -606,8 +648,9 @@ def test_copy_blocks_mla( kv_caches = [] for _ in range(num_layers): - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) kv_caches.append(kv_cache) @@ -624,9 +667,9 @@ def test_copy_blocks_mla( dst2 = dst_blocks[2 * i + 1] block_mapping.append((src, dst1)) block_mapping.append((src, dst2)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) for src, dst in block_mapping: for ref_cache in ref_caches: @@ -667,10 +710,12 @@ def test_swap_blocks_mla( entry_size = kv_lora_rank + qk_rope_head_dim - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) - dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) + dst_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype) @@ -682,9 +727,9 @@ def test_swap_blocks_mla( remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remaining_blocks, num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device="cpu" + ).view(-1, 2) opcheck( torch.ops._C_cache_ops.swap_blocks, @@ -699,7 +744,8 @@ def test_swap_blocks_mla( src_cache_clone[src].cpu(), dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " - f"{dst} in dst_cache.") + f"{dst} in dst_cache.", + ) @pytest.mark.parametrize("kv_lora_rank", [512]) @@ -709,42 +755,46 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("max_seq_len", [512]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("kv_cache_dtype", - ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize( + "kv_cache_dtype", ["auto"] +) # You can also test "fp8" if needed. 
@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_cache_mla( + kv_lora_rank, + qk_rope_head_dim, + block_size, + num_blocks, + max_seq_len, + batch_size, + dtype, + kv_cache_dtype, + device, +): entry_size = kv_lora_rank + qk_rope_head_dim - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) - seq_len_tensor = torch.randint(0, - max_seq_len + 1, (batch_size, ), - device=device) + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device) total_tokens = seq_len_tensor.sum() - cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=device) + cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) print("seq_len_tensor", seq_len_tensor) tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size - block_table = torch.empty((batch_size, num_blocks), - dtype=torch.int32, - device=device) + block_table = torch.empty( + (batch_size, num_blocks), dtype=torch.int32, device=device + ) for b in range(batch_size): perm = torch.randperm(num_blocks, device=device) block_table[b, :] = perm - dst = torch.zeros((total_tokens, entry_size), - dtype=src_cache.dtype, - device=device) + dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device) expected_batches = [] for b in range(batch_size): @@ -800,20 +850,16 @@ def test_concat_and_cache_mla_cpu( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -825,10 +871,7 @@ def test_concat_and_cache_mla_cpu( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = ref_temp @@ -838,6 +881,5 @@ def test_concat_and_cache_mla_cpu( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) torch.testing.assert_close(kv_cache, ref_kv_cache) diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py index 
1e7e7e0a7f84..58e8bd592ba4 100755 --- a/tests/kernels/attention/test_cascade_flash_attn.py +++ b/tests/kernels/attention/test_cascade_flash_attn.py @@ -7,11 +7,12 @@ import torch from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import (cascade_attention, - merge_attn_states) -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) +from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 192, 256] @@ -37,21 +38,14 @@ def test_merge_kernel( assert num_query_heads % num_kv_heads == 0 # Prepare inputs. - prefix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) - suffix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) + prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) + suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) # Run the kernel. output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype) - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) # Reference implementation. max_lse = torch.maximum(prefix_lse, suffix_lse) @@ -97,8 +91,10 @@ def test_cascade( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) current_platform.seed_everything(0) @@ -107,11 +103,9 @@ def test_cascade( num_query_heads = num_heads[0] num_kv_heads = num_heads[1] assert num_query_heads % num_kv_heads == 0 - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) seq_lens, common_prefix_len = seq_lens_and_common_prefix @@ -122,26 +116,21 @@ def test_cascade( max_kv_len = max(kv_lens) total_num_query_tokens = sum(query_lens) - query = torch.randn(total_num_query_tokens, - num_query_heads, - head_size, - dtype=dtype) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) assert common_prefix_len > 0 assert common_prefix_len % block_size == 0 num_common_kv_blocks = common_prefix_len // block_size # Make sure the first `num_common_kv_blocks` blocks are the same. 
- block_tables[:, :num_common_kv_blocks] = \ - block_tables[0, :num_common_kv_blocks] + block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks] # Run the regular attention. ref_output = flash_attn_varlen_func( @@ -161,8 +150,7 @@ def test_cascade( # Run cascade attention. assert all(common_prefix_len < kv_len for kv_len in kv_lens) - cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], - dtype=torch.int32) + cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32) suffix_kv_lens = kv_lens_tensor - common_prefix_len output = torch.empty_like(query) diff --git a/tests/kernels/attention/test_encoder_decoder_attn.py b/tests/kernels/attention/test_encoder_decoder_attn.py index a2e698646090..d82f28155bb5 100644 --- a/tests/kernels/attention/test_encoder_decoder_attn.py +++ b/tests/kernels/attention/test_encoder_decoder_attn.py @@ -13,12 +13,15 @@ import pytest import torch - from tests.kernels.utils import * + from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) +from vllm.attention.selector import ( + _Backend, + _cached_get_attn_backend, + global_force_attn_backend_context_manager, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.platforms import current_platform @@ -27,10 +30,10 @@ @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch): """ - Encoder-decoder is only supported on V0, so set + Encoder-decoder is only supported on V0, so set VLLM_USE_V1=0 for all tests in the module. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") # List of support backends for encoder/decoder models @@ -79,7 +82,7 @@ class TestPoint(NamedTuple): class TestResources(NamedTuple): - ''' + """ Encapsulates key components for performing an encoder/decoder attention test @@ -105,15 +108,17 @@ class TestResources(NamedTuple): i.e. XFormers * attn: Attention layer instance * kv_cache: shared key/value cache for all attention - ''' + """ scale: float attn: Attention kv_cache: torch.Tensor -def _make_test_resources(test_pt: TestPoint, ) -> TestResources: - ''' +def _make_test_resources( + test_pt: TestPoint, +) -> TestResources: + """ Build key components for performing encoder/decoder attention test. Note that @@ -137,7 +142,7 @@ class that Attention will automatically select when it is constructed. Returns: * TestResources data structure. - ''' + """ scale = float(1.0 / (test_pt.head_size**0.5)) attn = Attention( @@ -150,18 +155,19 @@ class that Attention will automatically select when it is constructed. 
if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache return TestResources( - scale, attn, - torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) + scale, attn, torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE) + ) # Construct KV cache - if test_pt.attn_type in (AttentionType.DECODER, - AttentionType.ENCODER_DECODER): - kv_cache = make_kv_cache(test_pt.num_blocks, - test_pt.num_heads, - test_pt.head_size, - test_pt.block_size, - device=CUDA_DEVICE, - backend=test_pt.backend_name) + if test_pt.attn_type in (AttentionType.DECODER, AttentionType.ENCODER_DECODER): + kv_cache = make_kv_cache( + test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, + device=CUDA_DEVICE, + backend=test_pt.backend_name, + ) else: kv_cache = torch.tensor([]) @@ -173,7 +179,7 @@ def _encoder_attn_setup( test_pt: TestPoint, test_rsrcs: TestResources, ) -> PhaseTestParameters: - ''' + """ Set up test vectors & data structures for encoder attention test. A triplet of synthetic query/key/value tensors are constructed. @@ -200,7 +206,7 @@ def _encoder_attn_setup( * PhaseTestParameters data structure comprising (1) packed query/key/value tensors, (2) the ideal output of attention computed using a naive implementation, and (3) KVCache field set to None - ''' + """ ( num_heads, @@ -220,33 +226,37 @@ def _encoder_attn_setup( # Make test tensors - qkv_in, _, _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) + qkv_in, _, _ = make_qkv( + batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE, + ) # Compute correct answer using naive non-causal attention # implementation - ideal_output = ref_masked_attention(qkv_in.query, - qkv_in.key, - qkv_in.value, - scale=scale, - q_seq_lens=qkv_in.q_seq_lens, - kv_seq_lens=qkv_in.kv_seq_lens) + ideal_output = ref_masked_attention( + qkv_in.query, + qkv_in.key, + qkv_in.value, + scale=scale, + q_seq_lens=qkv_in.q_seq_lens, + kv_seq_lens=qkv_in.kv_seq_lens, + ) - packed_ideal_output, _ = pack_tensor(ideal_output, - qkv_in.q_seq_lens, - device=CUDA_DEVICE) + packed_ideal_output, _ = pack_tensor( + ideal_output, qkv_in.q_seq_lens, device=CUDA_DEVICE + ) packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) return PhaseTestParameters( PackedQKVO(packed_qkv, packed_ideal_output), - None # No KV cache + None, # No KV cache ) @@ -255,7 +265,7 @@ def _decoder_attn_setup( test_rsrcs: TestResources, block_base_addr: int = 0, ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: - ''' + """ Set up test vectors & data structures for self-attention test. 
A triplet of synthetic query/key/value tensors are constructed ("baseline" @@ -309,7 +319,7 @@ def _decoder_attn_setup( (intended to be used as the base address for the encoder/ decoder cross-attention block-table, which is not constructed in this function) - ''' + """ ( num_heads, @@ -333,27 +343,30 @@ def _decoder_attn_setup( qkv, prefill_qkv, decode_qkv, - ) = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) + ) = make_qkv( + batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE, + ) # Compute correct answer using naive attention implementation # with causal attention mask - causal_mask = make_causal_mask(max_q_seq_len, - max_kv_seq_len).to(CUDA_DEVICE) + causal_mask = make_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention(qkv.query, - qkv.key, - qkv.value, - scale=scale, - custom_mask=causal_mask, - q_seq_lens=qkv.q_seq_lens, - kv_seq_lens=qkv.kv_seq_lens) + ideal_output = ref_masked_attention( + qkv.query, + qkv.key, + qkv.value, + scale=scale, + custom_mask=causal_mask, + q_seq_lens=qkv.q_seq_lens, + kv_seq_lens=qkv.kv_seq_lens, + ) # Split out the prefill- & decode-phase ideal answers & pack them @@ -361,16 +374,18 @@ def _decoder_attn_setup( decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_qkv.q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) + bdx, :prefill_q_seq_len + ] + decode_ideal_output[bdx, :] = ideal_output[ + bdx, prefill_q_seq_len : (prefill_q_seq_len + 1) + ] + + prefill_packed_ideal_output, _ = pack_tensor( + prefill_ideal_output, prefill_qkv.q_seq_lens, device=CUDA_DEVICE + ) + decode_packed_ideal_output, _ = pack_tensor( + decode_ideal_output, [1 for _ in range(batch_size)], device=CUDA_DEVICE + ) # Build prefill- & decode-phase data structures # for decoder self-attention. 
Block tables and @@ -398,17 +413,14 @@ def _decoder_attn_setup( decode_block_tables, slot_mapping_list, max_block_idx, - ) = make_block_tables_slot_mapping(block_size, - qkv.q_seq_lens, - device=CUDA_DEVICE, - block_base_addr=block_base_addr) + ) = make_block_tables_slot_mapping( + block_size, qkv.q_seq_lens, device=CUDA_DEVICE, block_base_addr=block_base_addr + ) ( prefill_slot_mapping, decode_slot_mapping, - ) = split_slot_mapping(slot_mapping_list, - qkv.q_seq_lens, - device=CUDA_DEVICE) + ) = split_slot_mapping(slot_mapping_list, qkv.q_seq_lens, device=CUDA_DEVICE) prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) @@ -418,11 +430,14 @@ def _decoder_attn_setup( qkv, PhaseTestParameters( # Prefill test params PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping), + ), PhaseTestParameters( # Decode test params PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping)), - max_block_idx) + KVMemoryMap(decode_block_tables, decode_slot_mapping), + ), + max_block_idx, + ) def _enc_dec_cross_attn_setup_reuses_query( @@ -433,7 +448,7 @@ def _enc_dec_cross_attn_setup_reuses_query( test_rsrcs: TestResources, block_base_addr: int = 0, ) -> tuple[PhaseTestParameters, PhaseTestParameters]: - ''' + """ Set up test vectors & data structures for cross-attention test. A triplet of synthetic cross-attention key/value tensors are constructed @@ -494,7 +509,7 @@ def _enc_dec_cross_attn_setup_reuses_query( along with (2) ideal attention output computed using a naive implementation, and (3) memory-mapping data structures appropriate for decode phase. - ''' + """ assert encoder_test_params.packed_qkvo.packed_qkv is not None assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None @@ -517,7 +532,8 @@ def _enc_dec_cross_attn_setup_reuses_query( decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens prefill_q_seq_lens = ( - prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + ) assert prefill_q_seq_lens is not None @@ -525,36 +541,42 @@ def _enc_dec_cross_attn_setup_reuses_query( cross_kv, _, _, - ) = make_qkv(batch_size, - max_decoder_seq_len, - max_encoder_seq_len, - num_heads, - head_size, - force_kv_seq_lens=encoder_seq_lens, - attn_type=AttentionType.ENCODER_DECODER, - device=CUDA_DEVICE) - - ideal_output = ref_masked_attention(decoder_query, - cross_kv.key, - cross_kv.value, - scale=scale, - q_seq_lens=decoder_seq_lens, - kv_seq_lens=cross_kv.kv_seq_lens) + ) = make_qkv( + batch_size, + max_decoder_seq_len, + max_encoder_seq_len, + num_heads, + head_size, + force_kv_seq_lens=encoder_seq_lens, + attn_type=AttentionType.ENCODER_DECODER, + device=CUDA_DEVICE, + ) + + ideal_output = ref_masked_attention( + decoder_query, + cross_kv.key, + cross_kv.value, + scale=scale, + q_seq_lens=decoder_seq_lens, + kv_seq_lens=cross_kv.kv_seq_lens, + ) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, 
- prefill_q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) + bdx, :prefill_q_seq_len + ] + decode_ideal_output[bdx, :] = ideal_output[ + bdx, prefill_q_seq_len : (prefill_q_seq_len + 1) + ] + + prefill_packed_ideal_output, _ = pack_tensor( + prefill_ideal_output, prefill_q_seq_lens, device=CUDA_DEVICE + ) + decode_packed_ideal_output, _ = pack_tensor( + decode_ideal_output, [1 for _ in range(batch_size)], device=CUDA_DEVICE + ) # Build prefill- & decode-phase data structures # for encoder/decoder cross-attention. Block tables and @@ -591,13 +613,16 @@ def _enc_dec_cross_attn_setup_reuses_query( decode_block_tables, prefill_slot_mapping_list, _, - ) = make_block_tables_slot_mapping(block_size, - cross_kv.kv_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) + ) = make_block_tables_slot_mapping( + block_size, + cross_kv.kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE, + ) - prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, - device=CUDA_DEVICE) + prefill_slot_mapping = maybe_make_long_tensor( + prefill_slot_mapping_list, device=CUDA_DEVICE + ) # Packed key/value (query is already provided) packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) @@ -605,10 +630,13 @@ def _enc_dec_cross_attn_setup_reuses_query( return ( PhaseTestParameters( # Prefill-phase test params PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping), + ), PhaseTestParameters( # Decode-phase test params PackedQKVO(None, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping))) + KVMemoryMap(decode_block_tables, decode_slot_mapping), + ), + ) def _run_encoder_attention_test( @@ -618,7 +646,7 @@ def _run_encoder_attention_test( test_pt: TestPoint, vllm_config: VllmConfig, ) -> torch.Tensor: - ''' + """ Run encoder attention. attn.forward() is passed attn_type=AttentionType.ENCODER in order @@ -641,7 +669,7 @@ def _run_encoder_attention_test( Returns: * Attention.forward() applied to packed {query,key,value} and & attn_metadata - ''' + """ assert attn_metadata.num_decode_tokens == 0 packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None @@ -654,7 +682,8 @@ def _run_encoder_attention_test( # TODO - Update the way we construct the query so that it # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) + -1, test_pt.num_heads * test_pt.head_size + ) return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) @@ -665,7 +694,7 @@ def _run_decoder_self_attention_test( test_pt: TestPoint, vllm_config: VllmConfig, ) -> torch.Tensor: - ''' + """ Run decoder self-attention test. attn.forward() is passed attn_type=AttentionType.DECODER @@ -687,7 +716,7 @@ def _run_decoder_self_attention_test( Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata - ''' + """ attn = test_rsrcs.attn packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None @@ -700,7 +729,8 @@ def _run_decoder_self_attention_test( # TODO - Update the way we construct the query so that it # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) + -1, test_pt.num_heads * test_pt.head_size + ) return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) @@ -712,7 +742,7 @@ def _run_encoder_decoder_cross_attention_test( test_pt: TestPoint, vllm_config: VllmConfig, ) -> torch.Tensor: - ''' + """ Run encoder/decoder cross-attention test. Via PhaseTestParameters data structures, consumes the same query utilized @@ -745,7 +775,7 @@ def _run_encoder_decoder_cross_attention_test( Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata - ''' + """ assert decoder_test_params.packed_qkvo.packed_qkv is not None attn = test_rsrcs.attn @@ -754,8 +784,8 @@ def _run_encoder_decoder_cross_attention_test( value = None else: cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) - value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) + key = None if cross_pckd_qkv is None else cross_pckd_qkv.key + value = None if cross_pckd_qkv is None else cross_pckd_qkv.value with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However @@ -765,7 +795,8 @@ def _run_encoder_decoder_cross_attention_test( # TODO - Update the way we construct the query so that it # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) + -1, test_pt.num_heads * test_pt.head_size + ) return attn.forward(reshaped_query, key, value) @@ -775,7 +806,7 @@ def set_reset_environment(attn_backend): # testing of the Flash Attention backend. Also clear the # cached value of the backend. 
default_dtype = torch.get_default_dtype() - if attn_backend.name == 'FLASH_ATTN': + if attn_backend.name == "FLASH_ATTN": torch.set_default_dtype(torch.bfloat16) _cached_get_attn_backend.cache_clear() yield @@ -784,8 +815,7 @@ def set_reset_environment(attn_backend): torch.set_default_dtype(default_dtype) -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.skipif(current_platform.is_rocm(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @@ -802,7 +832,7 @@ def test_encoder_only( max_dec_seq_len: int, max_enc_seq_len: int, ): - ''' + """ End-to-end encoder-only attention test: * Construct fake test vectors for (1) encoder attention @@ -830,15 +860,23 @@ def test_encoder_only( * block_size: KV cache block size * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences - ''' + """ # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test - test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) + test_pt = TestPoint( + num_heads, + head_size, + attn_backend.name, + batch_size, + block_size, + max_dec_seq_len, + max_enc_seq_len, + 4096, + AttentionType.ENCODER, + ) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -860,24 +898,26 @@ def test_encoder_only( decoder_test_params=None, encoder_test_params=enc_test_params, cross_test_params=None, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + ) # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( + enc_pckd_act_out: torch.Tensor = _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, prephase_attn_metadata, test_pt=test_pt, - vllm_config=vllm_config)) + vllm_config=vllm_config, + ) # - Is encoder attention result correct? 
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + enc_test_params, enc_pckd_act_out, attn_backend.name + ) -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.skipif(current_platform.is_rocm(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @@ -894,7 +934,7 @@ def test_e2e_enc_dec_attn( max_dec_seq_len: int, max_enc_seq_len: int, ) -> None: - ''' + """ End-to-end encoder/decoder test: * Construct fake test vectors for (1) encoder attention, @@ -954,22 +994,45 @@ def test_e2e_enc_dec_attn( * block_size: KV cache block size * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences - ''' + """ # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test - enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) - enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, - AttentionType.ENCODER_DECODER) - dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.DECODER) + enc_test_pt = TestPoint( + num_heads, + head_size, + attn_backend.name, + batch_size, + block_size, + max_dec_seq_len, + max_enc_seq_len, + 4096, + AttentionType.ENCODER, + ) + enc_dec_test_pt = TestPoint( + num_heads, + head_size, + attn_backend.name, + batch_size, + block_size, + max_dec_seq_len, + max_enc_seq_len, + 4096, + AttentionType.ENCODER_DECODER, + ) + dec_test_pt = TestPoint( + num_heads, + head_size, + attn_backend.name, + batch_size, + block_size, + max_dec_seq_len, + max_enc_seq_len, + 4096, + AttentionType.DECODER, + ) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -1010,7 +1073,8 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, enc_dec_test_pt, enc_dec_test_rsrcs, - block_base_addr=cross_block_base_addr) + block_base_addr=cross_block_base_addr, + ) # Shared prefill metadata structure assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None @@ -1021,19 +1085,23 @@ def test_e2e_enc_dec_attn( decoder_test_params=prephase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=prephase_cross_test_params, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + ) # PREFILL: encoder attention - enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata, - test_pt=enc_test_pt, - vllm_config=vllm_config) + enc_pckd_act_out = _run_encoder_attention_test( + enc_test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=enc_test_pt, + vllm_config=vllm_config, + ) # - Is encoder attention result correct? 
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + enc_test_params, enc_pckd_act_out, attn_backend.name + ) # PREFILL: decoder self-attention test @@ -1042,12 +1110,13 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, prephase_attn_metadata, test_pt=dec_test_pt, - vllm_config=vllm_config) + vllm_config=vllm_config, + ) # - Is prefill decoder self-attention correct? - assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + prephase_dec_test_params, prephase_dec_pckd_act_out, attn_backend.name + ) # PREFILL: encoder/decoder cross-attention test @@ -1057,12 +1126,13 @@ def test_e2e_enc_dec_attn( prephase_cross_test_params, prephase_attn_metadata, test_pt=enc_dec_test_pt, - vllm_config=vllm_config) + vllm_config=vllm_config, + ) # - Is prefill encoder/decoder cross-attention correct? - assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + prephase_cross_test_params, prephase_cross_pckd_act_out, attn_backend.name + ) # DECODE: build decode-phase attention metadata @@ -1073,7 +1143,8 @@ def test_e2e_enc_dec_attn( decoder_test_params=decphase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=decphase_cross_test_params, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + ) # DECODE: decoder self-attention test @@ -1082,12 +1153,13 @@ def test_e2e_enc_dec_attn( decphase_dec_test_params, decphase_attn_metadata, test_pt=dec_test_pt, - vllm_config=vllm_config) + vllm_config=vllm_config, + ) # - Is decode-phase decoder self-attention correct? - assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + decphase_dec_test_params, decphase_dec_pckd_act_out, attn_backend.name + ) # DECODE: encoder/decoder cross-attention test @@ -1097,9 +1169,10 @@ def test_e2e_enc_dec_attn( None, decphase_attn_metadata, test_pt=enc_dec_test_pt, - vllm_config=vllm_config) + vllm_config=vllm_config, + ) # - Is decode-phase encoder/decoder cross-attention correct? 
- assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out, - attn_backend.name) + assert_actual_matches_ideal( + decphase_cross_test_params, decphase_cross_pckd_act_out, attn_backend.name + ) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index bd3190d09b0f..ba1a540a1a0e 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -7,10 +7,12 @@ import torch from vllm.platforms import current_platform -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - flash_attn_with_kvcache, - is_fa_version_supported) +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + flash_attn_with_kvcache, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -42,7 +44,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -60,10 +62,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -104,11 +109,15 @@ def test_flash_attn_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(kv_lens) @@ -117,23 +126,19 @@ def test_flash_attn_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) q = query.unsqueeze(1) out = torch.empty_like(q) if use_out else None @@ -178,23 +183,27 @@ 
def test_flash_attn_with_paged_kv( if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -220,11 +229,15 @@ def test_varlen_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -234,30 +247,23 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) out = torch.empty_like(query) if use_out else None @@ -313,5 +319,7 @@ def test_varlen_with_paged_kv( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - 
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 3ad6e1d32911..19e7ad67dfdb 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -36,7 +36,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -54,10 +54,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -97,20 +100,16 @@ def test_flashinfer_decode_with_paged_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -131,35 +130,41 @@ def test_flashinfer_decode_with_paged_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_tensor_cores=((num_query_heads // num_kv_heads) > 4), + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, key_value_cache) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + 
scale=scale, + soft_cap=soft_cap, + ) + ( + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) @@ -169,11 +174,14 @@ def test_flashinfer_decode_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], - num_heads: tuple[int, int], - head_size: int, dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float]) -> None: +def test_flashinfer_prefill_with_paged_kv( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) num_seqs = len(seq_lens) @@ -185,16 +193,10 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], max_kv_len = max(kv_lens) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) @@ -204,10 +206,9 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], value_cache /= head_size**0.5 max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -231,8 +232,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -252,16 +252,20 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], key_value_cache, ) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) @@ -271,9 +275,13 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) def test_flashinfer_prefill_with_paged_fp8_kv( - seq_lens: list[tuple[int, 
int]], num_heads: tuple[int, int], - head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: pytest.skip("TODO: fix the accuracy issue") torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -288,17 +296,11 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_cache_dtype = torch.float8_e4m3fn - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -306,15 +308,15 @@ def test_flashinfer_prefill_with_paged_fp8_kv( k_scale = key_cache.amax().item() / 448.0 v_scale = value_cache.amax().item() / 448.0 - kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], - dim=1).to(kv_cache_dtype) + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to( + kv_cache_dtype + ) - assert (kv_cache_fp8.shape == key_value_cache.shape) + assert kv_cache_fp8.shape == key_value_cache.shape max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -338,8 +340,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -356,19 +357,23 @@ def test_flashinfer_prefill_with_paged_fp8_kv( output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache.squeeze(1), - value_cache=value_cache.squeeze(1), - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) del query del block_tables # verify prefill fp8 - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @@ -401,12 +406,9 @@ def test_flashinfer_decode_with_paged_fp8_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, 
block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -416,14 +418,13 @@ def test_flashinfer_decode_with_paged_fp8_kv( key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) - assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1 kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -444,32 +445,38 @@ def test_flashinfer_decode_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=use_tensor_cores) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) # Temporary fix: Increasing the tolerance. 
Seems like a flashinfer issue - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py index 96eee13695a9..e22e893acf8c 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py @@ -9,8 +9,9 @@ from vllm.platforms import current_platform if not current_platform.is_device_capability(100): - pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", - allow_module_level=True) + pytest.skip( + "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True + ) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -72,10 +73,9 @@ def test_flashinfer_trtllm_decode_with_baseline( key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) k_scale = v_scale = 1.0 kv_indptr = [0] kv_indices = [] @@ -96,30 +96,30 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout, - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_query_heads // num_kv_heads) > 4), + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, key_value_cache, scale) # TRTLLM Decode max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, - dtype=torch.int, - device=query.device) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=query.device) output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache( query.contiguous(), key_value_cache, @@ -136,5 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline( v_scale, ) - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - output_trtllm))}" + ( + torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - output_trtllm))}", + ) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 21b08e45fd6f..43b027ba8226 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -7,24 +7,28 @@ import pytest import torch -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + 
is_flashmla_supported, +) from vllm.triton_utils import triton def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: x, y = x.double(), y.double() - cos_diff = 1 - 2 * (x * y).sum().item() / max( - (x * x + y * y).sum().item(), 1e-12) + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) assert cos_diff < 1e-5 -FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ - if not is_flashmla_supported()[0] else "FlashMLA is supported" +FLASH_MLA_UNSUPPORTED_REASON = ( + is_flashmla_supported()[1] + if not is_flashmla_supported()[0] + else "FlashMLA is supported" +) -@pytest.mark.skipif(not is_flashmla_supported()[0], - reason=FLASH_MLA_UNSUPPORTED_REASON) + +@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192]) @@ -36,8 +40,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen): # TODO: parametrize using pytest dtype = torch.bfloat16 device = torch.device("cuda:0") @@ -47,30 +50,32 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, torch.manual_seed(0) random.seed(0) - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + ) - cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), - s_q) + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) total_seqlens = cache_seqlens.sum().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 q = torch.randn(b, s_q, h_q, d) - block_table = torch.arange(b * max_seqlen_pad // block_size, - dtype=torch.int32).view( - b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, - d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + float("nan") + ) blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv) + cache_seqlens, s_q * h_q // h_kv, h_kv + ) def flash_mla(): return flash_mla_with_kvcache( @@ -95,8 +100,7 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, - dtype=torch.bool).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -127,7 +131,7 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * 
h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(dtype).bits // 8 + ) + print(f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS, {bytes / 10**6 / t:.0f} GB/s") diff --git a/tests/kernels/attention/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py index de45ee1ed5cc..0e3da986299e 100644 --- a/tests/kernels/attention/test_lightning_attn.py +++ b/tests/kernels/attention/test_lightning_attn.py @@ -4,8 +4,7 @@ import pytest import torch -from vllm.model_executor.layers.lightning_attn import ( - linear_decode_forward_triton) +from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton from vllm.platforms import current_platform NUM_HEADS = [4, 8] @@ -17,8 +16,8 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): """Reference implementation of lightning attention core algorithm - - The difference from the main implementation is that this processes + + The difference from the main implementation is that this processes each step sequentially, instead of using parallelized triton kernels """ B, H, S, D = q.shape @@ -62,8 +61,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): # The actual implementation returns a tensor of shape [B, H, 2, D, E] # where dimension 2 contains both KV and KV history kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E] - final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], - dim=2) # [B, H, 2, D, E] + final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E] return output, final_kv_cache @@ -109,7 +107,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx): out_h = torch.matmul(q_bh, kv_new) # Update output and cache - output[b, h * D:(h + 1) * D] = out_h + output[b, h * D : (h + 1) * D] = out_h kv_caches[b, h] = kv_new return output @@ -135,12 +133,9 @@ def test_linear_decode_forward_triton( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_caches_copy = kv_caches.clone() @@ -150,15 +145,14 @@ def test_linear_decode_forward_triton( slot_idx = torch.arange(batch_size, device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) - torch.testing.assert_close(triton_output, - reference_output, - rtol=1e-1, - atol=1e-1) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) + torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1) torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -184,12 +178,9 @@ def test_linear_decode_forward_triton_with_padding( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * 
torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_caches_copy = kv_caches.clone() @@ -199,14 +190,15 @@ def test_linear_decode_forward_triton_with_padding( slot_idx = torch.tensor([0, 1, -1, 2], device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) - padding_mask = (slot_idx - != -1).unsqueeze(1).expand(-1, num_heads * head_size) + padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size) triton_masked = triton_output[padding_mask] reference_masked = reference_output[padding_mask] @@ -217,15 +209,11 @@ def test_linear_decode_forward_triton_with_padding( for i in range(batch_size): if valid_indices[i] > 0: - torch.testing.assert_close(kv_caches[i], - kv_caches_copy[i], - rtol=rtol, - atol=atol) + torch.testing.assert_close( + kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol + ) - torch.testing.assert_close(triton_masked, - reference_masked, - rtol=rtol, - atol=atol) + torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -249,39 +237,33 @@ def test_lightning_attention_reference( current_platform.seed_everything(42) base = 0.01 - q = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - k = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - v = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) + q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) ed = torch.zeros(num_heads, device="cuda") for h in range(num_heads): ed[h] = 0.1 * (h + 1) - kv_history = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_history = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_history_clone = kv_history.clone() ref_output, ref_kv_cache = reference_lightning_attention( - q, k, v, ed, 256, kv_history) + q, k, v, ed, 256, kv_history + ) from vllm.model_executor.layers.lightning_attn import lightning_attention + actual_output, actual_kv_cache = lightning_attention( - q, k, v, ed, 256, kv_history_clone) + q, k, v, ed, 256, kv_history_clone + ) atol, rtol = 1.5e-1, 1.5e-1 torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol) - torch.testing.assert_close(ref_kv_cache, - actual_kv_cache, - rtol=rtol, - atol=atol) + torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol) assert ref_output.shape == (batch_size, num_heads, seq_len, head_size) assert ref_kv_cache.shape == actual_kv_cache.shape diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index 9d1a301ebe30..eb9204dfaf15 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -7,19 +7,20 @@ from 
vllm._custom_ops import merge_attn_states as merge_attn_states_cuda from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states as merge_attn_states_triton) + merge_attn_states as merge_attn_states_triton, +) from vllm.platforms import current_platform # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # can be used to combine partial attention results (in the split-KV case) def merge_attn_states_torch( - output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] + output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] ): p_lse = prefix_lse s_lse = suffix_lse @@ -32,15 +33,13 @@ def merge_attn_states_torch( s_lse = s_lse - max_lse p_lse_exp = torch.exp(p_lse) s_lse_exp = torch.exp(s_lse) - out_se = (p_lse_exp + s_lse_exp) + out_se = p_lse_exp + s_lse_exp if output_lse is not None: output_lse = torch.log(out_se) + max_lse p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] - p_scale = torch.transpose(p_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - s_scale = torch.transpose(s_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] output = prefix_output * p_scale + suffix_output * s_scale return output, output_lse @@ -55,8 +54,10 @@ def merge_attn_states_torch( def generate_markdown_table(): global all_case_info - table_header = ("| tokens | heads | headsize | dtype " - "| device | torch | triton | cuda | speedup |") + table_header = ( + "| tokens | heads | headsize | dtype " + "| device | torch | triton | cuda | speedup |" + ) table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |" def shortly_dtype(dtype: torch.dtype) -> str: @@ -68,16 +69,26 @@ def shortly_device(device: str) -> str: print(table_header) print(table_separator) for info in all_case_info: - (num_tokens, num_heads, head_size, dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved) = info + ( + num_tokens, + num_heads, + head_size, + dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) = info dtype = shortly_dtype(dtype) device = shortly_device(device) - print(f"| {num_tokens} | {num_heads} | {head_size} " - f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " - f"| {avg_time_triton_kernel:.5f}ms " - f"| {avg_time_cuda_kernel:.5f}ms " - f"| {performance_improved:.4f}x |") + print( + f"| {num_tokens} | {num_heads} | {head_size} " + f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " + f"| {avg_time_triton_kernel:.5f}ms " + f"| {avg_time_cuda_kernel:.5f}ms " + f"| {performance_improved:.4f}x |" + ) @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @@ -85,29 +96,28 @@ def 
shortly_device(device: str) -> str: @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("output_dtype", DTYPES) @torch.inference_mode() -def test_merge_attn_states(num_tokens: int, num_query_heads: int, - head_size: int, output_dtype: torch.dtype): +def test_merge_attn_states( + num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype +): if not current_platform.is_cuda(): - pytest.skip('Currently only support compare triton merge_attn_states ' - 'with custom cuda merge_attn_states kernel') + pytest.skip( + "Currently only support compare triton merge_attn_states " + "with custom cuda merge_attn_states kernel" + ) NUM_TOKENS = num_tokens NUM_HEADS = num_query_heads HEAD_SIZE = head_size - print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " - f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " - f"Device: {current_platform.get_device_name()}") + print( + f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " + f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"Device: {current_platform.get_device_name()}" + ) # prefix_lse and suffix_lse contain inf and normal values - prefix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") - suffix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") + prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") # Generate boolean masks mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 @@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) mask_suffix = torch.logical_and(mask_suffix, ~combined_mask) - prefix_lse[mask_prefix] = float('inf') - suffix_lse[mask_suffix] = float('inf') + prefix_lse[mask_prefix] = float("inf") + suffix_lse[mask_suffix] = float("inf") # Other input tensors (need to be initialized but # no actual calculation needed) - output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS), - dtype=torch.float32, - device="cuda") - prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") + output = torch.zeros( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + output_lse = torch.zeros( + (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" + ) + prefix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + suffix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) warmup_times = 2 repeat_times = 20 @@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, suffix_lse_torch = suffix_lse.clone() for _ in range(warmup_times): output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + 
prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) end.record() torch.cuda.synchronize() total_time_torch_kernel += start.elapsed_time(end) @@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, end = torch.cuda.Event(enable_timing=True) for _ in range(warmup_times): - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) end.record() torch.cuda.synchronize() total_time_triton_kernel += start.elapsed_time(end) @@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, output_lse_cuda = output_lse.clone() for _ in range(warmup_times): - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) end.record() torch.cuda.synchronize() total_time_cuda_kernel += start.elapsed_time(end) @@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel print(f" Torch time: {avg_time_torch_kernel:.6f}ms") print(f"Triton time: {avg_time_triton_kernel:.6f}ms") - print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " - f"Performance: {performance_improved:.5f}x") + print( + f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " + f"Performance: {performance_improved:.5f}x" + ) print("-" * 100) # 4. Correctness compare @@ -232,35 +276,45 @@ def diff(a: torch.Tensor, b: torch.Tensor): # states operation. output_ref = output_ref_triton output_lse_ref = output_lse_ref_triton - torch.testing.assert_close(output_cuda.float(), - output_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol + ) print("Output all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}") print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}") print("-" * 100) - torch.testing.assert_close(output_lse_cuda.float(), - output_lse_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + ) print("Output LSE all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}") print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}") print("-" * 100) - print("All output values test passed! 
All inf values " - "are correctly replaced with -inf.") + print( + "All output values test passed! All inf values " + "are correctly replaced with -inf." + ) print("-" * 100) device = current_platform.get_device_name() all_case_info.append( - (NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved)) - if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * - len(NUM_QUERY_HEADS) * len(DTYPES)): + ( + NUM_TOKENS, + NUM_HEADS, + HEAD_SIZE, + output_dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) + ) + if len(all_case_info) == ( + len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES) + ): generate_markdown_table() diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 53c37554b15a..f97dd50cb5ca 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -5,6 +5,7 @@ * Tests for MultiHeadAttention layer """ + from unittest.mock import patch import pytest @@ -20,8 +21,7 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() @@ -74,9 +74,11 @@ def ref_attention( NUM_KV_HEADS = [1] HEAD_SIZES = [64, 80] # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = ( + [torch.half, torch.bfloat16, torch.float] + if not current_platform.is_rocm() + else [torch.half, torch.bfloat16] +) CUDA_DEVICES = ["cuda"] @@ -104,10 +106,9 @@ def test_mha_attn_forward( k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads) + attn = MultiHeadAttention( + num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads + ) output = attn(q, k, v) assert num_heads % num_kv_heads == 0 diff --git a/tests/kernels/attention/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py index f8b307c595de..44f3e42e8714 100644 --- a/tests/kernels/attention/test_mla_decode_cpu.py +++ b/tests/kernels/attention/test_mla_decode_cpu.py @@ -11,30 +11,24 @@ def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) ): bs, num_heads, v_head_dim = out.shape head_dim = query.shape[2] for i in range(bs): # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) v = kv[:, :, :v_head_dim] q = query[i].view(num_heads, 1, head_dim) - o = 
F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) out[i] = o.view(num_heads, v_head_dim) return out @@ -63,18 +57,17 @@ def test_mla_decode_cpu( torch.set_default_dtype(dtype) torch.manual_seed(0) - scale = d**(-0.5) + scale = d ** (-0.5) if varlen: seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) seq_lens = seq_lens.clip(2).to(torch.int32) else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32) max_seq_len = seq_lens.max().item() seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary? q = torch.randn(bs, h_q, d) - block_table = torch.arange(bs * seqlen_pad // block_size, - dtype=torch.int32) + block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32) block_table = block_table.view(bs, seqlen_pad // block_size) kv_cache = torch.randn(block_table.numel(), block_size, d) @@ -82,8 +75,7 @@ def test_mla_decode_cpu( kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan") out_mla = q.new_zeros(bs, h_q, dv) - ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, - seq_lens) + ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens) out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index b09e1bbc4279..5acab484ecdc 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -12,8 +12,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from vllm.attention.backends.xformers import _make_alibi_bias -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -22,9 +21,7 @@ NUM_QUERIES_PER_KV = [1, 8, 64] HEAD_SIZES = [128, 96, 24] DTYPES = [torch.float16] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] @@ -50,12 +47,10 @@ def test_contexted_kv_attention( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) current_platform.seed_everything(0) torch.set_default_device(device) @@ -93,38 +88,29 @@ def test_contexted_kv_attention( cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, 
head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -135,61 +121,71 @@ def test_contexted_kv_attention( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + 
op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) @@ -201,22 +197,24 @@ def test_contexted_kv_attention( # heads. # # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + query = query.view( + query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1] + ) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) + query_lens, seq_lens + ) if sliding_window > 0: - attn_bias = attn_bias.make_local_attention_from_bottomright( - sliding_window) + attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, @@ -239,7 +237,7 @@ def test_contexted_kv_attention( ) torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") output_ref = output_ref.reshape(output.shape) atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -262,12 +260,10 @@ def test_contexted_kv_attention_alibi( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) current_platform.seed_everything(0) torch.set_default_device(device) @@ -280,9 +276,9 @@ def test_contexted_kv_attention_alibi( def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -290,17 +286,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, 
torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes alibi_slopes = _get_alibi_slopes(num_heads).to(device) @@ -328,38 +323,29 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -370,82 +356,90 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + 
v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) # NOTE(DefTruth): In order to reuse _make_alibi_bias function, # we have to pad query tensor before MQA/GQA expanding. if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) + query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) query_pad.uniform_(-1e-3, 1e-3) seq_start = 0 query_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) + query_pad[seq_start:seq_end, ...] = torch.cat( + [ + torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...], + ], + dim=0, + ) seq_start += seq_len query_start += query_len query = query_pad @@ -456,11 +450,12 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # heads. # # see also: vllm/model_executor/layers/attention.py - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) # [seq, num_kv_heads, num_queries_per_kv, dk]=> # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the # codebase. We save some time reshaping alibi matrix at runtime. 
@@ -483,24 +478,23 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) + out = xops.memory_efficient_attention_forward( + query[:, seq_start:seq_end], + key[:, seq_start:seq_end], + value[:, seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale, + ) out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, - ...]) + seq_len, num_heads, head_size + ) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...]) seq_start += seq_len query_start += query_len torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -532,9 +526,16 @@ def test_contexted_kv_attention_f32( device: str, op: Callable, ) -> None: - test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, - sliding_window, dtype, kv_cache_dtype, device, - op) + test_contexted_kv_attention( + num_heads, + num_queries_per_kv, + head_size, + sliding_window, + dtype, + kv_cache_dtype, + device, + op, + ) @pytest.mark.optional @@ -555,5 +556,6 @@ def test_contexted_kv_attention_alibi_f32( device: str, op: Callable, ) -> None: - test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, - dtype, kv_cache_dtype, device, op) + test_contexted_kv_attention_alibi( + num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op + ) diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 34311b9ccd76..68b0fa5e6838 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -11,8 +11,7 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. 
- """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() @@ -21,38 +20,42 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") # Set the current platform to ROCm using monkeypatch - monkeypatch.setattr("vllm.attention.selector.current_platform", - RocmPlatform()) + monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) # Test standard ROCm attention backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) - assert (backend.get_name() == "ROCM_FLASH" - or backend.get_name() == "TRITON_ATTN_VLLM_V1") + assert ( + backend.get_name() == "ROCM_FLASH" + or backend.get_name() == "TRITON_ATTN_VLLM_V1" + ) # MLA test for deepseek related # change the attention backend to triton MLA m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) - assert (backend.get_name() == "TRITON_MLA" - or backend.get_name() == "TRITON_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) + assert ( + backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1" + ) # If attention backend is None # If use_mla is true # The selected backend is triton MLA m.setenv(STR_BACKEND_ENV_VAR, None) - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) - assert (backend.get_name() == "TRITON_MLA" - or backend.get_name() == "TRITON_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True) + assert ( + backend.get_name() == "TRITON_MLA" + or backend.get_name() == "TRITON_MLA_VLLM_V1" + ) # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) - assert (backend.get_name() == "ROCM_AITER_MLA" - or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) + assert ( + backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1" + ) # If attention backend is None # If use_mla is true @@ -60,7 +63,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # The selected backend is ROCM_AITER_MLA m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv("VLLM_ROCM_USE_AITER", "1") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) - assert (backend.get_name() == "ROCM_AITER_MLA" - or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) + assert ( + backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1" + ) diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py index 2dca720fe330..b893a4b820d9 100644 --- a/tests/kernels/attention/test_triton_decode_attention.py +++ b/tests/kernels/attention/test_triton_decode_attention.py @@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): num_kv_splits = 8 num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) - req_to_page = torch.randint(0, - CACHE_SIZE // PAGE_SIZE, - (B, num_pages_per_batch, 1), - device="cuda") + req_to_page = torch.randint( + 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" + ) req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_token.expand(B, 
num_pages_per_batch, PAGE_SIZE) - req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( - 1, 1, -1) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token[:, :seq_len].contiguous() @@ -46,7 +44,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): # o will have the same shape as q o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - b_seq_len = torch.full((B, ), seq_len, device="cuda") + b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( (B, H_Q, num_kv_splits, D_V + 1), diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 0cb7f5963c79..290f858959f7 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -14,9 +14,11 @@ BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] -QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ - None, torch.float8_e4m3fnuz -] +QDTYPES = ( + [None, torch.float8_e4m3fn] + if not current_platform.is_rocm() + else [None, torch.float8_e4m3fnuz] +) # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] @@ -42,7 +44,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -60,10 +62,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None and soft_cap > 0: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -77,9 +82,9 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -114,30 +119,23 @@ def test_triton_unified_attn( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + 
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) output = torch.empty_like(query) @@ -191,5 +189,7 @@ def test_triton_unified_attn( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 29c5e70a8ba8..fdb1c8adfd6e 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -5,27 +5,30 @@ import pytest import torch - from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck -from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, - GeluAndMul, MulAndSilu, - NewGELU, QuickGELU, - SiluAndMul) + +from vllm.model_executor.layers.activation import ( + FastGELU, + FatreluAndMul, + GeluAndMul, + MulAndSilu, + NewGELU, + QuickGELU, + SiluAndMul, +) from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize( - "activation", - ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]) + "activation", ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"] +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -67,7 +70,7 @@ def test_act_and_mul( torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "fatrelu": opcheck(fn, (out, x, threshold)) @@ -75,9 +78,14 @@ def test_act_and_mul( opcheck(fn, (out, x)) -@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast), - (NewGELU, torch.ops._C.gelu_new), - (QuickGELU, torch.ops._C.gelu_quick)]) +@pytest.mark.parametrize( + "activation", + [ + (FastGELU, torch.ops._C.gelu_fast), + (NewGELU, torch.ops._C.gelu_new), + (QuickGELU, torch.ops._C.gelu_quick), + ], +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -99,10 +107,9 @@ def test_activation( fn = activation[1] out = layer(x) ref_out = layer.forward_native(x) - torch.testing.assert_close(out, - ref_out, - atol=get_default_atol(out), - rtol=get_default_rtol(out)) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) + ) out = torch.empty_like(x) opcheck(fn, (out, x)) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py 
b/tests/kernels/core/test_fused_quant_layernorm.py index 19703b8a2f97..60467f696693 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -5,9 +5,9 @@ import pytest import torch +from tests.kernels.utils import opcheck import vllm._custom_ops as ops -from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm DTYPES = [torch.bfloat16, torch.float] @@ -24,9 +24,7 @@ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] EPS = 1e-6 @@ -34,13 +32,12 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_rms_norm(rms_norm_layer: RMSNorm, - x: torch.Tensor, - residual: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, Optional[torch.Tensor]]: +def ref_rms_norm( + rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor] +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -50,12 +47,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm, return out, residual -def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ref_dynamic_per_token_quant( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -64,9 +62,9 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, # Quant if quant_dtype == torch.float8_e4m3fn: - torch_out, scales = ops.scaled_fp8_quant(torch_out, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + torch_out, scales = ops.scaled_fp8_quant( + torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) else: assert quant_dtype == torch.int8 torch_out, scales = ops.scaled_int8_quant(torch_out) @@ -74,38 +72,41 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, return torch_out, scales, residual -def ref_impl(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, - residual, scale_ub) +def ref_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ref_dynamic_per_token_quant( + rms_norm_layer, x, quant_dtype, residual, scale_ub + ) -def ops_dynamic_per_token_quant(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ops_dynamic_per_token_quant( + weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: 
Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() - out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, - quant_dtype, scale_ub, - residual) + out, scales = ops.rms_norm_dynamic_per_token_quant( + x, weight, EPS, quant_dtype, scale_ub, residual + ) return out, scales, residual -def ops_impl(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, - scale_ub) +def ops_impl( + weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @@ -146,12 +147,14 @@ def test_rms_norm( residual = torch.randn_like(x) * scale if add_residual else None if scale_ub is not None: rms_x, _ = ref_rms_norm(layer, x, residual) - scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda') + scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda") - ref_out, ref_scales, ref_residual = \ - ref_impl(layer, x, quant_dtype, residual, scale_ub) - ops_out, ops_scales, ops_residual = \ - ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) + ref_out, ref_scales, ref_residual = ref_impl( + layer, x, quant_dtype, residual, scale_ub + ) + ops_out, ops_scales, ops_residual = ops_impl( + layer.weight, x, quant_dtype, residual, scale_ub + ) assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype @@ -160,15 +163,18 @@ def test_rms_norm( # big atol to account for round-off errors. 
assert torch.allclose(ref_out, ops_out, atol=1) else: - assert torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + assert torch.allclose( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) if add_residual: assert torch.allclose(ref_residual, ops_residual) output = torch.empty_like(x, dtype=quant_dtype) - scales = torch.empty((x.numel() // x.shape[-1], 1), - device=x.device, - dtype=torch.float32) - - opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant, - (output, x, layer.weight, scales, 1e-5, scale_ub, residual)) + scales = torch.empty( + (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 + ) + + opcheck( + torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual), + ) diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 3eac062738f8..37116060bd1a 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -3,21 +3,30 @@ import pytest import torch - from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck + from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, - 8199] # Arbitrary values for testing +HIDDEN_SIZES = [ + 8, + 768, + 769, + 770, + 771, + 5120, + 5124, + 5125, + 5126, + 8192, + 8199, +] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -58,11 +67,14 @@ def test_rms_norm( torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) if residual is not None: - opcheck(torch.ops._C.fused_add_rms_norm, - (x, residual, layer.weight.data, layer.variance_epsilon)) + opcheck( + torch.ops._C.fused_add_rms_norm, + (x, residual, layer.weight.data, layer.variance_epsilon), + ) else: - opcheck(torch.ops._C.rms_norm, - (out, x, layer.weight.data, layer.variance_epsilon)) + opcheck( + torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon) + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -102,36 +114,38 @@ def test_fused_rms_norm_quant( if add_residual: torch.ops._C.fused_add_rms_norm_static_fp8_quant( - out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6) + out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6 + ) # Unfused kernel is in-place so it goes second # Also use a separate clone of x to avoid modifying the input x_unfused = x.clone() torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, - quant_scale_t) + torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, quant_scale_t) torch.cuda.synchronize() - torch.testing.assert_close(residual_fused, - residual, - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) opcheck( torch.ops._C.fused_add_rms_norm_static_fp8_quant, - (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6), + ) else: - torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, - quant_scale_t, 1e-6) + 
torch.ops._C.rms_norm_static_fp8_quant( + out_quant_fused, x, weight, quant_scale_t, 1e-6 + ) torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, - quant_scale_t) - - opcheck(torch.ops._C.rms_norm_static_fp8_quant, - (out_quant_fused, x, weight, quant_scale_t, 1e-6)) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t) - torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), - out_quant.to(dtype=torch.float32), - atol=1e-3, - rtol=1e-3) + opcheck( + torch.ops._C.rms_norm_static_fp8_quant, + (out_quant_fused, x, weight, quant_scale_t, 1e-6), + ) + + torch.testing.assert_close( + out_quant_fused.to(dtype=torch.float32), + out_quant.to(dtype=torch.float32), + atol=1e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/core/test_opcheck.py b/tests/kernels/core/test_opcheck.py index 40ced08b933a..6d52669a7a25 100644 --- a/tests/kernels/core/test_opcheck.py +++ b/tests/kernels/core/test_opcheck.py @@ -5,7 +5,6 @@ """ import torch - from tests.kernels.utils import opcheck diff --git a/tests/kernels/core/test_permute_cols.py b/tests/kernels/core/test_permute_cols.py index e18f6230dbce..1470301ee1d4 100644 --- a/tests/kernels/core/test_permute_cols.py +++ b/tests/kernels/core/test_permute_cols.py @@ -3,16 +3,16 @@ import pytest import torch - from tests.kernels.utils import opcheck + from vllm._custom_ops import permute_cols -@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) -@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_permute_cols(shape, dtype): x = torch.randn(shape, dtype=dtype).cuda() perm = torch.randperm(x.shape[1]).to(torch.int).cuda() opcheck(torch.ops._C.permute_cols, (x, perm)) y = permute_cols(x, perm) - torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file + torch.testing.assert_close(y, x[:, perm]) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index ab6f1ccf881f..d2643c23cc17 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -6,8 +6,8 @@ import pytest import torch - from tests.kernels.allclose_default import get_default_atol, get_default_rtol + from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -19,30 +19,33 @@ BATCH_SIZES = [5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] USE_KEY = [True, False] -def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_flat_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads * head_size) # For testing sliced tensors -def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_padded_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size + 64) -def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def 
_get_batch_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size) TENSORS_SHAPES_FN = [ - _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape + _get_batch_tensor_shape, + _get_flat_tensor_shape, + _get_padded_tensor_shape, ] @@ -97,18 +100,21 @@ def test_rotary_embedding( ref_query, ref_key = rope.forward_native(positions, query, key) out_query, out_key = rope.forward(positions, query, key) # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) + torch.testing.assert_close( + out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query), + ) if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + torch.testing.assert_close( + out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key), + ) else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" + assert ref_key is None and out_key is None, "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -142,10 +148,14 @@ def test_batched_rotary_embedding( torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": (1, ) - }) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": (1,)}, + ) rope = rope.to(dtype=dtype, device=torch.get_default_device()) positions = torch.randint(0, max_position, (batch_size, seq_len)) @@ -160,25 +170,28 @@ def test_batched_rotary_embedding( # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) - out_query, out_key = rope.forward(positions, - query, - key, - offsets=torch.zeros(batch_size * seq_len, - dtype=torch.long, - device=device)) + out_query, out_key = rope.forward( + positions, + query, + key, + offsets=torch.zeros(batch_size * seq_len, dtype=torch.long, device=device), + ) # Compare the results. 
- torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) + torch.testing.assert_close( + out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query), + ) if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + torch.testing.assert_close( + out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key), + ) else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" + assert ref_key is None and out_key is None, "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -211,72 +224,98 @@ def test_batched_rotary_embedding_multi_lora( if rotary_dim is None: rotary_dim = head_size scaling_factors: list[int] = [1, 2, 4] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": tuple(scaling_factors)}, + ) rope = rope.to(dtype=dtype, device=torch.get_default_device()) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) key = torch.randn_like(query) if use_key else None offset_map = torch.tensor( list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) + accumulate( + [0] + + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ] + ) + ) + ) + query_types = torch.randint( + 0, len(scaling_factors), (batch_size, seq_len), device=device + ) query_offsets = offset_map[query_types] # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key, - query_offsets) - out_query, out_key = rope.forward(positions, query, key, - query_offsets.flatten()) + ref_query, ref_key = rope.forward_native(positions, query, key, query_offsets) + out_query, out_key = rope.forward(positions, query, key, query_offsets.flatten()) # Compare the results. 
- torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) + torch.testing.assert_close( + out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query), + ) if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + torch.testing.assert_close( + out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key), + ) else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" + assert ref_key is None and out_key is None, "expected returned key to be None" @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] BASES = [10000, 1000000] - ROPE_SCALINGS = (None, { - "rope_type": "linear", - "factor": (1, ) - }, { - "rope_type": "dynamic", - "factor": 1 - }) - settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, - ROPE_SCALINGS, DTYPES) + ROPE_SCALINGS = ( + None, + {"rope_type": "linear", "factor": (1,)}, + {"rope_type": "dynamic", "factor": 1}, + ) + settings = ( + HEAD_SIZES, + ROTARY_DIMS, + MAX_POSITIONS, + BASES, + IS_NEOX_STYLE, + ROPE_SCALINGS, + DTYPES, + ) rope_setting_id_map: dict[str, int] = {} for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # different settings cannot share the same rope module assert id(rope) not in rope_setting_id_map.values() assert all(x.dtype == dtype for x in rope.buffers()) @@ -284,11 +323,25 @@ def test_rope_module_cache(): rope_setting_id_map[str(setting)] = id(rope) for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # check if cache take effect assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index d1fd960bf115..4a46cc6dc6b9 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -8,28 +8,41 @@ import pytest import torch - from tests.kernels.utils import opcheck + from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -def rotary_embedding_opcheck(rot, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None): +def rotary_embedding_opcheck( + rot, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, +): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key 
tensors. if offsets is not None: - opcheck(torch.ops._C.batched_rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style, rot.rotary_dim, offsets)) + opcheck( + torch.ops._C.batched_rotary_embedding, + ( + positions, + query, + key, + rot.head_size, + cos_sin_cache, + rot.is_neox_style, + rot.rotary_dim, + offsets, + ), + ) else: - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + opcheck( + torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, rot.is_neox_style), + ) @pytest.mark.parametrize("device", ["cuda"]) @@ -40,39 +53,44 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False]) -def test_rotary_embedding_opcheck(dist_init, device, max_position, - is_neox_style, rotary_dim, head_size, - seq_len, use_key, head_stride_is_contiguous): +def test_rotary_embedding_opcheck( + dist_init, + device, + max_position, + is_neox_style, + rotary_dim, + head_size, + seq_len, + use_key, + head_stride_is_contiguous, +): batch_size = 1 base = 10000 num_heads = 7 - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) + rot = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, torch.float32 + ) - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device=device) + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) head_stride = head_size + (64 if head_stride_is_contiguous else 0) - query = torch.randn(batch_size, - seq_len, - num_heads, - head_stride, - dtype=torch.float32, - device=device) + query = torch.randn( + batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device + ) key = torch.randn_like(query) if use_key else None query = query[..., :head_size] key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) - offsets = torch.zeros(batch_size * seq_len, - device=device, - dtype=torch.long) + offsets = torch.zeros(batch_size * seq_len, device=device, dtype=torch.long) rotary_embedding_opcheck(rot, positions, query, key, offsets) # if we have a contiguous head stride, test the alternate # [..., num_heads * head_dim] shape/layout if head_stride_is_contiguous: rotary_embedding_opcheck( - rot, positions, query.flatten(start_dim=-2), - key.flatten(start_dim=-2) if use_key else None) + rot, + positions, + query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None, + ) diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index c71215e4c646..73738175e5c7 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -5,20 +5,14 @@ from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.") @pytest.mark.parametrize("device", CUDA_DEVICES) def test_cpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = 
get_cuda_view_from_cpu_tensor(cpu_tensor) assert cuda_view.device.type == "cuda" @@ -40,11 +34,7 @@ def test_cpu_write(device): @pytest.mark.parametrize("device", CUDA_DEVICES) def test_gpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor) assert cuda_view.device.type == "cuda" @@ -59,4 +49,4 @@ def test_gpu_write(device): assert cpu_tensor[0, 0] == 2 assert cpu_tensor[2, 3] == 4 - assert cpu_tensor[4, 5] == -2 \ No newline at end of file + assert cpu_tensor[4, 5] == -2 diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index 411bd9e904b0..f5bac4f1ac12 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -10,7 +10,9 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.platforms import current_platform @@ -39,18 +41,15 @@ def causal_conv1d_ref( seqlen = x.shape[-1] dim, width = weight.shape if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) + dtype_in + ) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: @@ -59,12 +58,9 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None, - cache_seqlens=None): +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -91,24 +87,25 @@ def causal_conv1d_update_ref(x, assert weight.shape == (dim, width) if cache_seqlens is None: x_new = torch.cat([conv_state, x], dim=-1).to( - weight.dtype) # (batch, dim, state_len + seqlen) + weight.dtype + ) # (batch, dim, state_len + seqlen) conv_state.copy_(x_new[:, :, -state_len:]) else: width_idx = torch.arange( - -(width - 1), 0, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand( - -1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], - dim=-1).to(weight.dtype) - copy_idx = torch.arange( - seqlen, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, - state_len).unsqueeze(1).expand(-1, dim, -1) + -(width - 1), 0, dtype=torch.long, device=x.device + ).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = ( + torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + ) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze( + 0 + ) + cache_seqlens.unsqueeze(1) + copy_idx = 
torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, - groups=dim)[:, :, -seqlen:] + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] if unsqueeze: out = out.squeeze(-1) return (out if activation is None else F.silu(out)).to(dtype=dtype_in) @@ -117,15 +114,17 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID): +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -150,8 +149,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("seqlen", [1]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, - itype): +def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -167,23 +165,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation) - out_ref = causal_conv1d_update_ref(x_ref, - conv_state_ref, - weight, - bias, - activation=activation) + out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) + out_ref = causal_conv1d_update_ref( + x_ref, conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("seqlen", [1, 3]) @@ -192,9 +183,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("batch_size", [3]) -def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, - width, seqlen, has_bias, - silu_activation, itype): +def test_causal_conv1d_update_with_batch_gather( + batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype +): device = "cuda" rtol, atol = (3e-4, 
1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -209,31 +200,30 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, total_entries = 10 * batch_size # x will be (batch, dim, seqlen) with contiguous along dim-axis - x = torch.randn(padded_batch_size, seqlen, dim, device=device, - dtype=itype).transpose(1, 2) + x = torch.randn( + padded_batch_size, seqlen, dim, device=device, dtype=itype + ).transpose(1, 2) x_ref = x.clone() conv_state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[conv_state_indices] = False - padded_state_indices = torch.concat([ - conv_state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) # conv_state will be (cache_lines, dim, state_len) # with contiguous along dim-axis - conv_state = torch.randn(total_entries, - width - 1, - dim, - device=device, - dtype=itype).transpose(1, 2) + conv_state = torch.randn( + total_entries, width - 1, dim, device=device, dtype=itype + ).transpose(1, 2) conv_state_for_padding_test = conv_state.clone() @@ -242,22 +232,23 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation, - conv_state_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID) - out_ref = causal_conv1d_update_ref(x_ref[:batch_size], - conv_state_ref, - weight, - bias, - activation=activation) + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref( + x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.equal(conv_state[unused_states_bool], - conv_state_for_padding_test[unused_states_bool]) + assert torch.equal( + conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool] + ) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @@ -265,12 +256,13 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096]) -@pytest.mark.parametrize('dim', [64, 4096]) -@pytest.mark.parametrize('with_padding', [True, False]) -@pytest.mark.parametrize('batch', [4, 10]) -def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, - has_bias, silu_activation, itype): +@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("dim", [64, 4096]) +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch", [4, 10]) +def test_causal_conv1d_varlen( + batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype +): device = "cuda" 
torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -288,19 +280,19 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) x = rearrange( torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype), - "b s d -> b d s")[:, 4096:4096 + dim, :] + "b s d -> b d s", + )[:, 4096 : 4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) @@ -309,34 +301,34 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(total_entries, - width - 1, - dim, - device=x.device, - dtype=x.dtype).transpose(1, 2) + final_states = torch.randn( + total_entries, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) final_states_ref = final_states.clone() - has_initial_states = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=x.device) - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=x.device)[:batch_size] - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) - out = causal_conv1d_fn(x.squeeze(0), - weight, - bias=bias, - conv_states=final_states, - query_start_loc=cumsum.cuda(), - cache_indices=padded_state_indices, - has_initial_state=has_initial_states, - activation=activation, - pad_slot_id=PAD_SLOT_ID) + has_initial_states = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ + :batch_size + ] + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + out = causal_conv1d_fn( + x.squeeze(0), + weight, + bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID, + ) out_ref = [] out_ref_b = [] @@ -353,16 +345,20 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[ - padded_state_indices[i]].unsqueeze(0), - initial_states=final_states_ref[padded_state_indices[i]]. 
- unsqueeze(0) if has_initial_states[i] else None)) + final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None, + ) + ) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref_tensor = torch.cat(out_ref, dim=0) - assert torch.allclose(final_states[state_indices], - final_states_ref[state_indices], - rtol=rtol, - atol=atol) - unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose( + final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + atol=atol, + ) + unpadded_out = out[:, : out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index f5c6a18614ff..e05592f5678e 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -5,10 +5,12 @@ import pytest import torch - from tests.utils import multi_gpu_test -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) + +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -24,14 +26,15 @@ (64, 2), (64, 4), # hidden_size be divisible by num_gpus (100, 5), # and n_groups must divide hidden_size - ]) + ], +) @pytest.mark.parametrize("dtype", [torch.float16]) def test_mixer2_gated_norm_multi_gpu( batch_size: int, seq_len: int, hidden_size_n_groups: tuple[int, int], dtype: torch.dtype, - device: str = 'cuda', + device: str = "cuda", ): hidden_size, n_groups = hidden_size_n_groups num_processes = 2 @@ -39,17 +42,19 @@ def test_mixer2_gated_norm_multi_gpu( def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=( - num_processes, - batch_size, - seq_len, - hidden_size, - n_groups, - dtype, - device, - ), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs, + ) run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) @@ -71,20 +76,22 @@ def mixer2_gated_norm_tensor_parallel( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) # create random weights an inputs - weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + weight = torch.rand((hidden_size,), dtype=dtype, device=device) hidden_states = torch.randn(batch_size, seq_len, hidden_size) gate_states = torch.randn(batch_size, seq_len, hidden_size) @@ -97,14 +104,18 @@ def mixer2_gated_norm_tensor_parallel( # create gated-norm without TP to compute reference # - utilize mock patching to disable TP when - with (unittest.mock.patch( + 
with ( + unittest.mock.patch( "vllm.model_executor.layers.mamba.mamba_mixer2." "get_tensor_model_parallel_world_size", - return_value=1), - unittest.mock.patch( - "vllm.model_executor.layers.mamba.mamba_mixer2." - "get_tensor_model_parallel_rank", - return_value=0)): + return_value=1, + ), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_rank", + return_value=0, + ), + ): mixer_single_gpu = Mixer2RMSNormGated( full_hidden_size=hidden_size, full_n_groups=n_groups, @@ -115,11 +126,13 @@ def mixer2_gated_norm_tensor_parallel( # generate and compare N = hidden_size // world_size output = mixer( - hidden_states[..., local_rank * N:(local_rank + 1) * N], - gate_states[..., local_rank * N:(local_rank + 1) * N], + hidden_states[..., local_rank * N : (local_rank + 1) * N], + gate_states[..., local_rank * N : (local_rank + 1) * N], ) ref_output = mixer_single_gpu(hidden_states, gate_states) - torch.allclose(output, - ref_output[..., local_rank * N:(local_rank + 1) * N], - atol=1e-3, - rtol=1e-3) + torch.allclose( + output, + ref_output[..., local_rank * N : (local_rank + 1) * N], + atol=1e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 8dece26ddb29..a338c95c7c88 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -5,25 +5,20 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_scan_fn, + selective_state_update, +) from vllm.platforms import current_platform -def selective_state_update_ref(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False): +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -73,16 +68,17 @@ def selective_state_update_ref(state, assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * - A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) + dA = torch.exp( + rearrange(dt, "b h d -> b h d 1") * A + ) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) dB = rearrange(dt, "b h d -> b h d 1") * rearrange( - B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_(state * dA + - dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + B, "b h n -> b h 1 n" + ) # (batch, nheads, dim, dstate) + state.copy_( + state * dA + dB * rearrange(x, "b h d -> b h d 1") + ) # (batch, dim, dstate out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) @@ -92,18 +88,20 @@ def selective_state_update_ref(state, return out -def selective_scan_ref(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - return_last_state=False, - prev_state=None, - final_state_out=None): +def 
selective_scan_ref( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + prev_state=None, + final_state_out=None, +): """ u: r(B D L) delta: r(B D L) @@ -132,26 +130,26 @@ def selective_scan_ref(u, C = C.float() x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) else: if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) + y = torch.einsum("bdn,dn->bd", x, C) else: if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) if i == u.shape[2] - 1: if final_state_out is None: final_state_out = x @@ -166,20 +164,22 @@ def selective_scan_ref(u, return out if not return_last_state else (out, final_state_out) -def selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - cu_seq_len=None, - cache_indices=None, - has_initial_state=None, - ssm_states=None, - pad_slot_id=PAD_SLOT_ID): +def selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + cu_seq_len=None, + cache_indices=None, + has_initial_state=None, + ssm_states=None, + pad_slot_id=PAD_SLOT_ID, +): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -206,30 +206,55 @@ def selective_scan_opcheck_fn(u, # Disable test_autograd_registration for now as it seems to trigger # a bogus error. 
- opcheck(torch.ops._C.selective_scan_fwd, - (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states, pad_slot_id), - test_utils=["test_schema", "test_faketensor"]) - - -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize('has_delta_bias', [True]) -@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) + opcheck( + torch.ops._C.selective_scan_fwd, + ( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cu_seq_len, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ), + test_utils=["test_schema", "test_faketensor"], + ) + + +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) @pytest.mark.parametrize("scan_chunks", [1, 2, 3]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, - has_z, has_delta_bias, delta_softplus, seqlen, itype, - wtype, scan_chunks): +def test_selective_scan( + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + seqlen, + itype, + wtype, + scan_chunks, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -242,7 +267,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, batch_size = 1 dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() if not is_variable_B: B_shape = [dim, dstate] @@ -250,9 +275,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B_shape = [batch_size, dstate, seqlen] else: B_shape = [batch_size, varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() if not is_variable_C: C_shape = [dim, dstate] @@ -260,27 +283,27 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, C_shape = [batch_size, dstate, seqlen] else: C_shape = [batch_size, varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() - z = torch.randn(batch_size, dim, seqlen, device=device, - dtype=itype) if has_z else None + z = ( + torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + if has_z + else None + ) z_ref = z.clone() if has_z 
else None - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * - torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() state_shape = (batch_size, u.shape[1], int(A.shape[1])) - state = torch.randn(state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False) state_ref = state.clone() out = None out_ref = None @@ -312,9 +335,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=_z, delta_bias=delta_bias, delta_softplus=delta_softplus, - has_initial_state=torch.ones(batch_size, - device=u.device, - dtype=torch.bool) if c > 0 else None) + has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) + if c > 0 + else None, + ) outs.append(out) if len(outs) > 1: out = torch.cat(outs, dim=-1) @@ -329,27 +353,29 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=z_ref, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=True) + return_last_state=True, + ) assert out is not None and out_ref is not None assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert state is not None and state_ref is not None assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D, - z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - ssm_states=state) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + ssm_states=state, + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @@ -373,51 +399,47 @@ def test_selective_state_update(dim, dstate, has_z, itype): D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None state_ref = state.detach().clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + out = selective_state_update( + state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32]) -@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32]) +@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096]) @pytest.mark.parametrize("return_last_state", [True]) -@pytest.mark.parametrize('has_delta_bias', [True]) 
-@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [False, True]) -def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, - varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, - itype, wtype): +def test_selective_scan_varlen( + with_padding, + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + return_last_state, + seqlen, + itype, + wtype, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -441,72 +463,79 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0).cuda() + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() B_shape = [varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() C_shape = [varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() z = torch.randn(dim, seqlen, device=device, dtype=itype) z_ref = z.clone() - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() out = None out_ref = None prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) - prev_state = torch.randn(prev_state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + prev_state = torch.randn( + prev_state_shape, device=u.device, dtype=itype, requires_grad=False + ) prev_state_ref = 
prev_state.clone() - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=u.device)[:batch_size] - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[ + :batch_size + ] + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) - - has_initial_state = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=u.device) - out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + + has_initial_state = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device + ) + out = selective_scan_fn( + u, + prev_state, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + ) outs_ref = [] splits = [ torch.split(var, seqlens[0], dim=-1) @@ -528,33 +557,46 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, delta_softplus=delta_softplus, return_last_state=return_last_state, prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) - if has_initial_state[i] else None, - final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( - 0)) + if has_initial_state[i] + else None, + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0), + ) outs_ref.append(out_ref_s) out_ref = torch.cat(outs_ref, dim=-1)[0] - unpadded_out = out[:, :out_ref[0].shape[-1]] + unpadded_out = out[:, : out_ref[0].shape[-1]] print("Output diff max", (unpadded_out - out_ref).max()) print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state, prev_state) - - -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + prev_state, + ) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) -def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, - has_z, itype): +def test_selective_state_update_with_batch_indices( + with_padding, dim, dstate, has_z, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -569,17 +611,17 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, total_entries = 10 * batch_size 
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 @@ -590,60 +632,59 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].clone() state_before = state.clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID) - out_ref = selective_state_update_ref(state_ref, - x[:batch_size], - dt[:batch_size], - A, - B[:batch_size], - C[:batch_size], - D=D, - z=z[:batch_size], - dt_bias=dt_bias, - dt_softplus=True) + out = selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = selective_state_update_ref( + state_ref, + x[:batch_size], + dt[:batch_size], + A, + B[:batch_size], + C[:batch_size], + D=D, + z=z[:batch_size], + dt_bias=dt_bias, + dt_softplus=True, + ) print("Output diff max", (out[:batch_size] - out_ref).max()) print("Output diff mean", (out[:batch_size] - out_ref).mean()) print("Output state diff max", (state[state_indices, :] - state_ref).max()) - print("Output state diff mean", - (state[state_indices, :] - state_ref).mean()) + print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) # test padded entries stay the same if with_padding: - assert torch.equal(state_before[unused_states_bool], - state[unused_states_bool]) - assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) - assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) - assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) - assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + assert torch.equal(state_before[unused_states_bool], state[unused_states_bool]) + assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :]) + assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :]) + assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :]) + assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :]) # test "real" entries - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("tie_hdim", [False, True]) 
@pytest.mark.parametrize("ngroups", [1, 2, 4]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 4096]) def test_selective_state_update_with_heads_with_batch_indices( - dim, dstate, ngroups, has_z, tie_hdim, itype): + dim, dstate, ngroups, has_z, tie_hdim, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) if itype == torch.bfloat16: @@ -655,69 +696,53 @@ def test_selective_state_update_with_heads_with_batch_indices( nheads = dim // headdim total_entries = 10 * batch_size - state = torch.randn(total_entries, - nheads, - headdim, - dstate, - dtype=itype, - device=device) + state = torch.randn( + total_entries, nheads, headdim, dstate, dtype=itype, device=device + ) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) + dtype=torch.int32, device=device + ) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) if not tie_hdim: - dt = torch.randn(batch_size, - nheads, - headdim, - device=device, - dtype=itype) + dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 D = torch.randn(nheads, headdim, device=device) else: - dt = repeat(torch.randn(batch_size, nheads, device=device, - dtype=itype), - "b h -> b h p", - p=headdim) - dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, - "h -> h p", - p=headdim) - A = repeat(-torch.rand(nheads, device=device) - 1.0, - "h -> h p n", - p=headdim, - n=dstate) + dt = repeat( + torch.randn(batch_size, nheads, device=device, dtype=itype), + "b h -> b h p", + p=headdim, + ) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) + A = repeat( + -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate + ) D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) B = torch.randn(batch_size, ngroups, dstate, device=device) C = torch.randn(batch_size, ngroups, dstate, device=device) z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].detach().clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices, - pad_slot_id=PAD_SLOT_ID) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + out = selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 6a3f21ba543f..a71057977087 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -6,11 +6,11 @@ import torch.nn.functional as F from einops import rearrange, repeat -from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - 
mamba_chunk_scan_combined) +from vllm.model_executor.layers.mamba.ops.ssd_combined import mamba_chunk_scan_combined from vllm.platforms import current_platform from vllm.v1.attention.backends.mamba_attn import ( - _query_start_loc_to_chunk_indices_offsets) + _query_start_loc_to_chunk_indices_offsets, +) # Added by the IBM Team, 2024 @@ -22,12 +22,10 @@ def segsum(x): """Calculates segment sum.""" T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=-1) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=0) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum @@ -46,8 +44,9 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): assert X.shape[1] % block_len == 0 # Rearrange into blocks/chunks - X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) - for x in (X, A, B, C)) + X, A, B, C = ( + rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C) + ) A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) @@ -74,7 +73,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): # 4. Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) - Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) # Add output of intra-chunk and inter-chunk terms # (diagonal and off-diagonal blocks) @@ -82,61 +81,48 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): return Y, final_state -def generate_random_inputs(batch_size, - seqlen, - n_heads, - d_head, - itype, - device='cuda'): - +def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"): current_platform.seed_everything(0) - A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device)) dt = F.softplus( - torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - - 4) - X = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - B = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - C = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4 + ) + X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) return A, dt, X, B, C -def generate_continuous_batched_examples(example_lens_by_batch, - num_examples, - full_length, - last_taken, - exhausted, - n_heads, - d_head, - itype, - device='cuda'): - +def generate_continuous_batched_examples( + example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device="cuda", +): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed # them in continuous batches to the kernels # generate the full-length 
example - A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs( + num_examples, full_length, n_heads, d_head, itype + ) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // 4) + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4 + ) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch # e.g., example_lens=(8, 4) means take 8 samples from first eg, # 4 examples from second eg, etc def get_continuous_batch(example_lens: tuple[int, ...]): - indices = [] for i, x in enumerate(example_lens): c = last_taken.get(i, 0) @@ -144,8 +130,10 @@ def get_continuous_batch(example_lens: tuple[int, ...]): last_taken[i] = (c + x) % full_length exhausted[i] = last_taken[i] == 0 - return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) - ]).unsqueeze(0) for x in (dt, X, B, C)) + return ( + torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0) + for x in (dt, X, B, C) + ) # internal function that maps "n" to the appropriate right boundary # value when forming continuous batches from examples of length given @@ -157,19 +145,20 @@ def end_boundary(n: int): IND_E = None for spec in example_lens_by_batch: - # get the (maybe partial) example seen in this cont batch dt2, X2, B2, C2 = get_continuous_batch(spec) # get the metadata - cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - seq_idx = torch.zeros(cu_seqlens[-1], - dtype=torch.int32, - device=cu_seqlens.device) - for i, (srt, end) in enumerate(zip( + cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0) + seq_idx = torch.zeros( + cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device + ) + for i, (srt, end) in enumerate( + zip( cu_seqlens, cu_seqlens[1:], - )): + ) + ): seq_idx[srt:end] = i # for cont batch @@ -179,18 +168,19 @@ def end_boundary(n: int): IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] - yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], - cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + yield ( + [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)], + cu_seqlens, + seq_idx.unsqueeze(0), + (A, dt2, X2, B2, C2), + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) @pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) -def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, - itype): - +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): # this tests the kernels on a single example (no batching) # set seed @@ -200,30 +190,27 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, # it is not an operational limitation. 
seqlen, chunk_size = seq_len_chunk_size - A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, - B, C, chunk_size) + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, chunk_size + ) - Y, final_state = mamba_chunk_scan_combined(X, - dt, - A, - B, - C, - chunk_size, - D=None, - return_final_states=True) + Y, final_state = mamba_chunk_scan_combined( + X, dt, A, B, C, chunk_size, D=None, return_final_states=True + ) # just test the last in sequence torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) # just test the last head # NOTE, in the kernel we always cast states to fp32 - torch.allclose(final_state[:, -1], - final_state_min[:, -1].to(torch.float32), - atol=1e-3, - rtol=1e-3) + torch.allclose( + final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-3, + rtol=1e-3, + ) @pytest.mark.parametrize("itype", [torch.float32, torch.float16]) @@ -232,32 +219,39 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, @pytest.mark.parametrize( "seq_len_chunk_size_cases", [ - # small-ish chunk_size (8) (64, 8, 2, [(64, 32), (64, 32)]), (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary - (64, 8, 2, [(4, 4), (4, 4), (4, 4), - (4, 4)]), # chunk_size larger than cont batches - (64, 8, 5, [ - (64, 32, 16, 8, 8), - (8, 16, 32, 16, 8), - (8, 8, 16, 32, 16), - ]), # mode examples with varied lengths - + ( + 64, + 8, + 2, + [(4, 4), (4, 4), (4, 4), (4, 4)], + ), # chunk_size larger than cont batches + ( + 64, + 8, + 5, + [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ], + ), # mode examples with varied lengths # odd chunk_size - (64, 29, 2, [(11, 4), (13, 23), (19, 22), - (21, 15)]), # irregular sizes - + (64, 29, 2, [(11, 4), (13, 23), (19, 22), (21, 15)]), # irregular sizes # large-ish chunk_size (256) - (64, 256, 1, [(5, ), (1, ), (1, ), - (1, )]), # irregular sizes with small sequences - (64, 256, 2, [(5, 30), (1, 2), (1, 2), - (1, 2)]), # irregular sizes with small sequences - ]) -def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, - itype): - + (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences + ( + 64, + 256, + 2, + [(5, 30), (1, 2), (1, 2), (1, 2)], + ), # irregular sizes with small sequences + ], +) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype): # this test with multiple examples in a continuous batch # (i.e. 
chunked prefill) @@ -270,13 +264,17 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, states = None for Y_min, cu_seqlens, seq_idx, ( - A, dt, X, B, C) in generate_continuous_batched_examples( - cases, num_examples, seqlen, last_taken, exhausted, n_heads, - d_head, itype): - - chunk_indices, chunk_offsets = \ - _query_start_loc_to_chunk_indices_offsets( - cu_seqlens, chunk_size, cu_seqlens[-1]) + A, + dt, + X, + B, + C, + ) in generate_continuous_batched_examples( + cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype + ): + chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1] + ) Y, new_states = mamba_chunk_scan_combined( X, @@ -296,9 +294,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # just test the last in sequence for i in range(num_examples): - # just test one dim and dstate - Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_eg = Y[0, cu_seqlens[i] : cu_seqlens[i + 1], 0, 0] Y_min_eg = Y_min[i][:, 0, 0] torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) @@ -306,5 +303,5 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, states = new_states for i, clear in exhausted.items(): if clear: - states[i].fill_(0.) + states[i].fill_(0.0) exhausted[i] = False diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py index b95d87cd04f5..d46847fbf6a3 100644 --- a/tests/kernels/moe/modular_kernel_tools/cli_args.py +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -9,18 +9,19 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from .common import Config -from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +from .mk_objects import ( + MK_ALL_PREPARE_FINALIZE_TYPES, + MK_FUSED_EXPERT_TYPES, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, +) def make_config_arg_parser(description: str): - def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: for pf in MK_ALL_PREPARE_FINALIZE_TYPES: if pf.__name__ == s: return pf - raise ValueError( - f"Cannot find a PrepareFinalize type that matches {s}") + raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}") def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: for fe in MK_FUSED_EXPERT_TYPES: @@ -45,15 +46,18 @@ def to_quant_torch_dtype(s: str) -> torch.dtype: "--pf-type", type=to_pf_class_type, required=True, - help=("Choose a PrepareFinalize Type : " - f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"), + help=( + "Choose a PrepareFinalize Type : " + f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}" + ), ) parser.add_argument( "--experts-type", type=to_experts_class_type, required=True, - help=(f"Choose a FusedExpert type : " - f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"), + help=( + f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}" + ), ) parser.add_argument( "-m", @@ -74,66 +78,65 @@ def to_quant_torch_dtype(s: str) -> torch.dtype: default=1024, help="N dimension of the first fused-moe matmul", ) - parser.add_argument("--num-experts", - type=int, - default=32, - help="Global num experts") - parser.add_argument("--topk", - nargs="+", - type=int, - default=[4, 1], - help="num topk") + parser.add_argument( + "--num-experts", type=int, default=32, help="Global num experts" + ) + parser.add_argument("--topk", nargs="+", type=int, 
default=[4, 1], help="num topk") parser.add_argument( "--fused-moe-chunk-size", type=int, - help="Fused moe chunk size used for the non-batched fused experts impl." + help="Fused moe chunk size used for the non-batched fused experts impl.", ) # Quant args - parser.add_argument("--quant-dtype", - type=to_quant_torch_dtype, - help="Quant datatype") - parser.add_argument("--per-token-quantized-activations", - action='store_true', - help=("The input activations must be per-token " - "quantized")) - parser.add_argument("--per-channel-quantized-weights", - action="store_true", - help="The weights must be per-channel quantized.") - parser.add_argument("--block-shape", - nargs="+", - type=int, - help="Quantization block shape") + parser.add_argument( + "--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype" + ) + parser.add_argument( + "--per-token-quantized-activations", + action="store_true", + help=("The input activations must be per-token quantized"), + ) + parser.add_argument( + "--per-channel-quantized-weights", + action="store_true", + help="The weights must be per-channel quantized.", + ) + parser.add_argument( + "--block-shape", nargs="+", type=int, help="Quantization block shape" + ) # Torch trace profile generation args - parser.add_argument("--torch-trace-dir-path", - type=str, - default=None, - help="Get torch trace for single execution") + parser.add_argument( + "--torch-trace-dir-path", + type=str, + default=None, + help="Get torch trace for single execution", + ) return parser def _validate_args(args: argparse.Namespace): - if args.quant_dtype is not None: assert args.quant_dtype == torch.float8_e4m3fn if args.block_shape is not None: assert len(args.block_shape) == 2, ( - f"block shape must have 2 elements. got {args.block_shape}") + f"block shape must have 2 elements. 
got {args.block_shape}" + ) if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: - assert args.world_size == 1, ( - "Single GPU objects need world size set to 1") + assert args.world_size == 1, "Single GPU objects need world size set to 1" if args.torch_trace_dir_path is not None: from pathlib import Path + assert Path(args.torch_trace_dir_path).is_dir(), ( - f"Please create {args.torch_trace_dir_path}") + f"Please create {args.torch_trace_dir_path}" + ) def make_config(args: argparse.Namespace) -> Config: - _validate_args(args) quant_config = None @@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config: quant_dtype=args.quant_dtype, per_act_token_quant=args.per_token_quantized_activations, per_out_ch_quant=args.per_channel_quantized_weights, - block_shape=args.block_shape) + block_shape=args.block_shape, + ) return Config( Ms=args.m, @@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config: fused_experts_type=args.experts_type, fused_moe_chunk_size=args.fused_moe_chunk_size, world_size=args.world_size, - torch_trace_dir_path=args.torch_trace_dir_path) + torch_trace_dir_path=args.torch_trace_dir_path, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index fd99e8dc5c98..48d4beb8c294 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -10,38 +10,54 @@ from tests.kernels.utils import torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size + # Fused experts and PrepareFinalize imports from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) + BatchedTritonExperts, + NaiveBatchedExperts, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - TritonExperts) +from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from .parallel_utils import ProcessGroupInfo -from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, - make_quant_fp8_weights, per_token_cast_to_fp8) +from .utils import ( + make_block_quant_fp8_weights, + make_non_quant_weights, + make_quant_fp8_weights, + per_token_cast_to_fp8, +) if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) if has_deep_ep(): from 
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) + DeepEPHTPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + DeepEPLLPrepareAndFinalize, + ) def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: @@ -110,8 +126,7 @@ def is_per_act_token_quant(self) -> bool: def is_per_tensor_act_quant(self) -> bool: if self.quant_config is None: return False - return (not self.is_per_act_token_quant - and self.quant_block_shape is None) + return not self.is_per_act_token_quant and self.quant_block_shape is None @property def is_per_out_ch_quant(self) -> bool: @@ -136,7 +151,8 @@ def topk_ids_dtype(self) -> Optional[torch.dtype]: if self.prepare_finalize_type == PplxPrepareAndFinalize: topk_ids_dtype = torch.uint32 elif self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize, ]: topk_ids_dtype = torch.int64 return topk_ids_dtype @@ -147,7 +163,7 @@ def num_local_experts(self) -> int: def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: """ - make env data for vllm launch. + make env data for vllm launch. """ vllm_config = VllmConfig() vllm_config.parallel_config.data_parallel_size = self.world_size @@ -159,34 +175,45 @@ def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: } if self.fused_moe_chunk_size is not None: env_dict.update( - {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)} + ) return vllm_config, env_dict def is_fp8_block_quantized(self): - return (self.quant_dtype == torch.float8_e4m3fn - and self.quant_block_shape is not None) + return ( + self.quant_dtype == torch.float8_e4m3fn + and self.quant_block_shape is not None + ) def is_batched_prepare_finalize(self): return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize + PplxPrepareAndFinalize, + DeepEPLLPrepareAndFinalize, ] def is_batched_fused_experts(self): return self.fused_experts_type in [ - CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, - NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts + CutlassExpertsFp8, + BatchedDeepGemmExperts, + BatchedTritonExperts, + NaiveBatchedExperts, + BatchedTritonOrDeepGemmExperts, ] def is_standard_fused_experts(self): return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts + CutlassExpertsFp8, + DeepGemmExperts, + TritonOrDeepGemmExperts, + TritonExperts, ] def is_fe_16bit_supported(self): return self.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - NaiveBatchedExperts, TritonExperts + BatchedTritonExperts, + BatchedTritonOrDeepGemmExperts, + NaiveBatchedExperts, + TritonExperts, ] def is_fe_fp8_supported(self): @@ -214,8 +241,10 @@ def is_fe_block_fp8_supported(self): def is_fe_supports_chunking(self): return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts + CutlassExpertsFp8, + DeepGemmExperts, + TritonOrDeepGemmExperts, + TritonExperts, ] def needs_deep_gemm(self): @@ -229,7 +258,8 @@ def needs_pplx(self): def needs_deep_ep(self): return self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize, ] def all2all_backend(self): @@ -243,8 +273,9 @@ def 
all2all_backend(self): def needs_all2all(self): return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize, - DeepEPLLPrepareAndFinalize + PplxPrepareAndFinalize, + DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize, ] def is_valid(self): @@ -261,14 +292,16 @@ def is_valid(self): return False # Check quantization sanity - if (int(self.is_per_act_token_quant) + - int(self.is_per_tensor_act_quant) + - int(self.quant_block_shape is not None)) > 1: + if ( + int(self.is_per_act_token_quant) + + int(self.is_per_tensor_act_quant) + + int(self.quant_block_shape is not None) + ) > 1: # invalid quant config return False # check bf16 / fp16 support - is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) + is_16bit = self.dtype.itemsize == 2 and self.quant_dtype is None if is_16bit and not self.is_fe_16bit_supported(): return False @@ -309,10 +342,10 @@ class WeightTensors: def describe(self): s = "" s += "== Weight Tensors: \n" - s += f' - {_describe_tensor(self.w1, "w1")} \n' - s += f' - {_describe_tensor(self.w2, "w2")} \n' - s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' - s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + s += f" - {_describe_tensor(self.w1, 'w1')} \n" + s += f" - {_describe_tensor(self.w2, 'w2')} \n" + s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n" + s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n" return s def to_current_device(self): @@ -322,13 +355,10 @@ def to_current_device(self): if is_quantized: assert self.w1_scale is not None assert self.w2_scale is not None - self.w1_scale = self.w1_scale.to( - device=torch.cuda.current_device()) - self.w2_scale = self.w2_scale.to( - device=torch.cuda.current_device()) + self.w1_scale = self.w1_scale.to(device=torch.cuda.current_device()) + self.w2_scale = self.w2_scale.to(device=torch.cuda.current_device()) - def slice_weights(self, rank: int, - num_local_experts: int) -> "WeightTensors": + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] @@ -344,13 +374,11 @@ def slice_weights(self, rank: int, @staticmethod def make(config: Config) -> "WeightTensors": - if config.quant_dtype is None: # just make normal dtype weights - w1, w2 = make_non_quant_weights(e=config.E, - n=config.N, - k=config.K, - dtype=config.dtype) + w1, w2 = make_non_quant_weights( + e=config.E, n=config.N, k=config.K, dtype=config.dtype + ) return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) assert config.quant_dtype == torch.float8_e4m3fn @@ -361,10 +389,7 @@ def make(config: Config) -> "WeightTensors": k=config.K, per_out_channel_quant=config.is_per_out_ch_quant, ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale) + return WeightTensors(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale) assert config.quant_block_shape is not None w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( @@ -373,10 +398,7 @@ def make(config: Config) -> "WeightTensors": k=config.K, block_size=config.quant_block_shape, ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale) + return WeightTensors(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale) @dataclass @@ -393,22 +415,22 @@ class RankTensors: def describe(self): s = "" s += "== Rank Tensors: \n" - s += f' - {_describe_tensor(self.hidden_states, "HS")} \n' - s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n' - s += f' - 
{_describe_tensor(self.topk_weights, "topk_weights")} \n' - s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n' - s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n' + s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n" + s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n" + s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n" + s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n" + s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n" return s @staticmethod def make_hidden_states( - config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + config: Config, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Return hidden_states """ m, k, dtype = (config.M, config.K, config.dtype) - a = (torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0) + a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0 if config.quant_dtype is None: return a, None @@ -419,36 +441,29 @@ def make_hidden_states( # first - so further quantize and dequantize will yield the same # values. if config.is_per_tensor_act_quant: - a_q, a_scales = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=False) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False) return a_q.float().mul(a_scales).to(dtype), a_scales if config.is_per_act_token_quant: - a_q, a_scales = ops.scaled_fp8_quant(a, - use_per_token_if_dynamic=True) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True) return a_q.float().mul(a_scales).to(dtype), None assert config.quant_block_shape is not None block_k = config.quant_block_shape[1] a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k) - return a_q.float().view( - (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None + return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to( + dtype + ), None @staticmethod def make(config: Config, pgi: ProcessGroupInfo): - dtype = config.dtype topk, m, _ = (config.topk, config.M, config.K) - hidden_states, hidden_states_scale = RankTensors.make_hidden_states( - config) - - num_local_experts, global_num_experts = (config.num_local_experts, - config.E) - score = torch.randn((m, global_num_experts), - device="cuda", - dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, - False) + hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config) + + num_local_experts, global_num_experts = (config.num_local_experts, config.E) + score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) topk_ids = topk_ids.to(config.topk_ids_dtype) # distribute topk_ids evenly @@ -458,14 +473,15 @@ def make(config: Config, pgi: ProcessGroupInfo): expert_map = None if config.world_size > 1: - expert_map = torch.full((global_num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full( + (global_num_experts,), fill_value=-1, dtype=torch.int32 + ) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - expert_map = expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + expert_map = expert_map.to( + device=torch.cuda.current_device(), dtype=torch.int32 + ) return RankTensors( hidden_states=hidden_states, @@ -477,29 +493,30 @@ def make(config: Config, pgi: ProcessGroupInfo): ) -def reference_moe_impl(config: Config, weights: WeightTensors, - 
rank_tensors: RankTensors) -> torch.Tensor: - - return torch_experts(a=rank_tensors.hidden_states, - w1=weights.w1, - w2=weights.w2, - topk_weight=rank_tensors.topk_weights, - topk_ids=rank_tensors.topk_ids, - global_num_experts=config.E, - expert_map=None, - w1_scale=weights.w1_scale, - w2_scale=weights.w2_scale, - a1_scale=rank_tensors.hidden_states_scale, - quant_dtype=config.quant_dtype, - per_act_token_quant=config.is_per_act_token_quant, - block_shape=config.quant_block_shape, - apply_router_weights_on_input=config.topk == 1) +def reference_moe_impl( + config: Config, weights: WeightTensors, rank_tensors: RankTensors +) -> torch.Tensor: + return torch_experts( + a=rank_tensors.hidden_states, + w1=weights.w1, + w2=weights.w2, + topk_weight=rank_tensors.topk_weights, + topk_ids=rank_tensors.topk_ids, + global_num_experts=config.E, + expert_map=None, + w1_scale=weights.w1_scale, + w2_scale=weights.w2_scale, + a1_scale=rank_tensors.hidden_states_scale, + quant_dtype=config.quant_dtype, + per_act_token_quant=config.is_per_act_token_quant, + block_shape=config.quant_block_shape, + apply_router_weights_on_input=config.topk == 1, + ) def make_fused_experts( - config: Config, moe: FusedMoEConfig, - num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: - + config: Config, moe: FusedMoEConfig, num_dispatchers: int +) -> mk.FusedMoEPermuteExpertsUnpermute: use_fp8 = config.quant_dtype == torch.float8_e4m3fn batch_kwargs = { "max_num_tokens": moe.max_num_tokens, @@ -547,8 +564,7 @@ def make_fused_experts( experts = NaiveBatchedExperts(**kwargs) elif config.fused_experts_type == CutlassExpertsFp8: use_batched_format = config.is_batched_prepare_finalize() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) + num_experts = moe.num_local_experts if use_batched_format else moe.num_experts kwargs = { "max_experts_per_worker": num_experts, "out_dtype": moe.in_dtype, @@ -556,7 +572,7 @@ def make_fused_experts( "per_out_ch_quant": config.is_per_out_ch_quant, "block_shape": config.quant_block_shape, "num_dispatchers": num_dispatchers, - "use_batched_format": use_batched_format + "use_batched_format": use_batched_format, } print(f"Making CutlassExpertsFp8 {kwargs} ...") experts = CutlassExpertsFp8(**kwargs) @@ -564,14 +580,15 @@ def make_fused_experts( return experts -def make_modular_kernel(config: Config, - vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: - +def make_modular_kernel( + config: Config, vllm_config: VllmConfig +) -> mk.FusedMoEModularKernel: def next_power_of_2(x): import math + if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return 2 ** math.ceil(math.log2(x)) # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( @@ -598,11 +615,11 @@ def next_power_of_2(x): else: prepare_finalize = MoEPrepareAndFinalizeNoEP() - fused_experts = make_fused_experts(config, moe, - prepare_finalize.num_dispatchers()) + fused_experts = make_fused_experts(config, moe, prepare_finalize.num_dispatchers()) modular_kernel = mk.FusedMoEModularKernel( - prepare_finalize=prepare_finalize, fused_experts=fused_experts) + prepare_finalize=prepare_finalize, fused_experts=fused_experts + ) return modular_kernel @@ -623,8 +640,7 @@ def run_modular_kernel( mk = make_modular_kernel(config, vllm_config) mk_kwargs = { - "hidden_states": rank_tensors.hidden_states.clone( - ), # impls might update the tensor in place + "hidden_states": rank_tensors.hidden_states.clone(), # impls might update the tensor in place "w1": rank_weights.w1, "w2": 
rank_weights.w2, "topk_weights": rank_tensors.topk_weights, diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 5dbfdfc153f9..6aa64d3f4929 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -13,10 +13,18 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.platforms import current_platform -from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, - run_modular_kernel) -from .mk_objects import (MK_FUSED_EXPERT_TYPES, - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS) +from .common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) +from .mk_objects import ( + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, +) from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config @@ -37,8 +45,9 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -59,8 +68,7 @@ def rank_worker( rank_tensors = RankTensors.make(cfgx, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) with set_current_vllm_config(vllm_config): ref_out = reference_moe_impl(cfgx, weights, rank_tensors) @@ -69,28 +77,27 @@ def rank_worker( def make_feature_matrix(csv_file_path: str): - from dataclasses import asdict import pandas as pd - def add_to_results(config: Config, - success: Result, - results_df: Optional[pd.DataFrame] = None): + def add_to_results( + config: Config, success: Result, results_df: Optional[pd.DataFrame] = None + ): config_dict = asdict(config) - config_dict['prepare_finalize_type'] = config_dict[ - 'prepare_finalize_type'].__name__ - config_dict['fused_experts_type'] = config_dict[ - 'fused_experts_type'].__name__ - config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant - quant_config_dict = config_dict['quant_config'] - del config_dict['quant_config'] + config_dict["prepare_finalize_type"] = config_dict[ + "prepare_finalize_type" + ].__name__ + config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__ + config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant + quant_config_dict = config_dict["quant_config"] + del config_dict["quant_config"] if quant_config_dict is None: quant_config = FusedMoEQuantConfig(None) quant_config_dict = asdict(quant_config) config_dict |= quant_config_dict - result_dict = config_dict | {'success': success.name} + result_dict = config_dict | {"success": success.name} result_df = pd.DataFrame([result_dict]) if results_df is None: @@ -111,22 +118,26 @@ def add_to_results(config: Config, Q_TYPES = MK_QUANT_CONFIGS combinations = list( - product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)) + product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES) + ) results_df: Optional[pd.DataFrame] = None for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm( - combinations): #noqa: E501 - config = Config(Ms=[m], - K=k, - N=n, - E=e, - topks=topks, - dtype=dtype, - prepare_finalize_type=pf_type, - 
fused_experts_type=experts_type, - quant_config=quant_config, - world_size=2, - fused_moe_chunk_size=None) + combinations + ): # noqa: E501 + config = Config( + Ms=[m], + K=k, + N=n, + E=e, + topks=topks, + dtype=dtype, + prepare_finalize_type=pf_type, + fused_experts_type=experts_type, + quant_config=quant_config, + world_size=2, + fused_moe_chunk_size=None, + ) success = None if config.is_valid(): @@ -134,9 +145,14 @@ def add_to_results(config: Config, try: weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, - vllm_config, env_dict, config, - weights) + parallel_launch_with_config( + config.world_size, + rank_worker, + vllm_config, + env_dict, + config, + weights, + ) success = Result.PASS except Exception as _: success = Result.FAIL @@ -149,25 +165,33 @@ def add_to_results(config: Config, results_df.to_csv(f"{csv_file_path}") -if __name__ == '__main__': +if __name__ == "__main__": import argparse from pathlib import Path - parser = argparse.ArgumentParser(description=( - "Make ModularKernel feature matrix \n" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501 - "-f ./feature_matrices/feature_matrix.csv")) - - parser.add_argument("-f", - "--feature-matrix-csv-file-path", - type=str, - required=True, - help="File name to Generate a .csv file") + + parser = argparse.ArgumentParser( + description=( + "Make ModularKernel feature matrix \n" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501 + "-f ./feature_matrices/feature_matrix.csv" + ) + ) + + parser.add_argument( + "-f", + "--feature-matrix-csv-file-path", + type=str, + required=True, + help="File name to Generate a .csv file", + ) args = parser.parse_args() csv_path = args.feature_matrix_csv_file_path - assert csv_path.endswith( - 'csv'), f"Need a file path ending with .csv, got {csv_path}" - assert Path(csv_path).parent.is_dir( - ), f"Cannot find parent directory for {Path(csv_path).parent}" + assert csv_path.endswith("csv"), ( + f"Need a file path ending with .csv, got {csv_path}" + ) + assert Path(csv_path).parent.is_dir(), ( + f"Cannot find parent directory for {Path(csv_path).parent}" + ) make_feature_matrix(args.feature_matrix_csv_file_path) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 73214066f7ea..2ae28baac6f8 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -5,43 +5,54 @@ # Fused experts and PrepareFinalize imports from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) + BatchedTritonExperts, + NaiveBatchedExperts, +) from vllm.model_executor.layers.fused_moe.layer import TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + 
MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.utils import has_deep_ep, has_pplx if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) + DeepEPHTPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + DeepEPLLPrepareAndFinalize, + ) if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] if has_pplx(): MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] if has_deep_ep(): MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize, ] MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] -MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +MK_ALL_PREPARE_FINALIZE_TYPES = ( + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES +) MK_FUSED_EXPERT_TYPES = [ BatchedDeepGemmExperts, @@ -57,30 +68,40 @@ MK_QUANT_CONFIGS = [ None, # per-channel / per-column weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=False, - block_shape=None), + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None, + ), # per-channel / per-column weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=True, - block_shape=None), + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None, + ), # per-tensor weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None, + ), # per-tensor weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=True, - block_shape=None), + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None, + ), # block-quantized weights and 128 block per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=[128, 128]), + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128], + ), # TODO (varun) : Should we test the following combinations ? 
# block-quantized weights and per-token activations # block-quantized weights and per-tensor activations diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 1f8d21a7a702..e05de1aa0231 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -6,13 +6,11 @@ from typing import Any, Callable, Optional import torch -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec from vllm.config import VllmConfig, set_current_vllm_config -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed import init_distributed_environment, initialize_model_parallel from vllm.utils import get_open_port ## Parallel Processes Utils @@ -30,10 +28,11 @@ class ProcessGroupInfo: device: torch.device -def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, - local_rank: int): - +def _set_vllm_config( + vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int +): import tempfile + temp_file = tempfile.mkstemp()[1] set_current_vllm_config(vllm_config) @@ -47,13 +46,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, ) initialize_model_parallel( - tensor_model_parallel_size=vllm_config.parallel_config. - tensor_parallel_size, - pipeline_model_parallel_size=vllm_config.parallel_config. - pipeline_parallel_size, + tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, ) - cpu_group = torch.distributed.new_group(list(range(world_size)), - backend="gloo") + cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") return cpu_group @@ -63,8 +59,7 @@ def _worker_parallel_launch( world_local_size: int, node_rank: int, init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, - P], None], + worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None], vllm_config: Optional[VllmConfig], env_dict: Optional[dict], *args: P.args, @@ -132,7 +127,8 @@ def parallel_launch_with_config( worker, vllm_config, env_dict, - ) + args, + ) + + args, nprocs=world_size, join=True, ) diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index dd16ffb2eabe..0de1762f6425 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -14,28 +14,31 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config -def do_profile(fn: Callable, - fn_kwargs: dict[Any, Any], - pgi: ProcessGroupInfo, - config: Config, - num_warmups: int = 5): +def do_profile( + fn: Callable, + fn_kwargs: dict[Any, Any], + pgi: ProcessGroupInfo, + config: Config, + num_warmups: int = 5, +): for _ in range(num_warmups): fn(**fn_kwargs) with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - record_shapes=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=True, ) as tprof: fn(**fn_kwargs) 
torch.cuda.synchronize(torch.cuda.current_device()) # TODO (varun): Add a descriptive trace file name tprof.export_chrome_trace( - f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json") + f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json" + ) def profile_modular_kernel( @@ -82,8 +85,9 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -108,20 +112,25 @@ def rank_worker( def run(config: Config): weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights + ) -if __name__ == '__main__': +if __name__ == "__main__": from .cli_args import make_config, make_config_arg_parser - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() assert args.torch_trace_dir_path is not None, ( - "Please pass in a directory to store torch traces") + "Please pass in a directory to store torch traces" + ) config = make_config(args) run(config) diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py index 09bb4a34f318..2a465fd9b4c3 100644 --- a/tests/kernels/moe/modular_kernel_tools/utils.py +++ b/tests/kernels/moe/modular_kernel_tools/utils.py @@ -8,12 +8,12 @@ def per_token_cast_to_fp8( - x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: int +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x x_view = x.view(m, -1, block_size) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) @@ -21,8 +21,8 @@ def per_token_cast_to_fp8( def per_block_cast_to_fp8( - x: torch.Tensor, block_size_k: int, - block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size_k: int, block_size_n: int +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( @@ -34,8 +34,9 @@ def per_block_cast_to_fp8( device=x.device, ) x_padded[:m, :n] = x - x_view = x_padded.view(-1, block_size_k, - x_padded.size(1) // block_size_k, block_size_n) + x_view = x_padded.view( + -1, block_size_k, x_padded.size(1) // block_size_k, block_size_n + ) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) 
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() @@ -86,24 +87,23 @@ def make_block_quant_fp8_weights( w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) - w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), - device=device, - dtype=torch.float32) - w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), - device=device, - dtype=torch.float32) + w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), device=device, dtype=torch.float32) + w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), device=device, dtype=torch.float32) - assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, - (k + (block_k - 1)) // block_k) + assert w1_s.shape == ( + e, + (2 * n + (block_n - 1)) // block_n, + (k + (block_k - 1)) // block_k, + ) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size_k=block_k, - block_size_n=block_n) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size_k=block_k, - block_size_n=block_n) + w1[i], w1_s[i] = per_block_cast_to_fp8( + w1_bf16[i], block_size_k=block_k, block_size_n=block_n + ) + w2[i], w2_s[i] = per_block_cast_to_fp8( + w2_bf16[i], block_size_k=block_k, block_size_n=block_n + ) return w1, w2, w1_s, w2_s @@ -127,16 +127,14 @@ def make_quant_fp8_weights( n_b_scales = 2 * n if per_out_channel_quant else 1 k_b_scales = k if per_out_channel_quant else 1 - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_channel_quant) + w1[expert], use_per_token_if_dynamic=per_out_channel_quant + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_channel_quant) + w2[expert], use_per_token_if_dynamic=per_out_channel_quant + ) return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index 1ad361ae0733..9d087ad13b82 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -3,6 +3,7 @@ """ DeepEP test utilities """ + import dataclasses import os import traceback @@ -10,17 +11,18 @@ import torch from torch.distributed import ProcessGroup -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec from vllm.utils import get_open_port, has_deep_ep if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) + DeepEPHTPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + DeepEPLLPrepareAndFinalize, + ) ## Parallel Processes Utils @@ -96,7 +98,8 @@ def parallel_launch( 0, f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", worker, - ) + args, + ) + + args, nprocs=world_size, join=True, ) @@ -118,48 +121,57 @@ class DeepEPLLArgs: use_fp8_dispatch: bool -def make_deepep_ht_a2a(pg: ProcessGroup, - pgi: 
ProcessGroupInfo, - dp_size: int, - ht_args: DeepEPHTArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - +def make_deepep_ht_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + ht_args: DeepEPHTArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # high throughput a2a num_nvl_bytes = 1024 * 1024 * 1024 # 1GB num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 - buffer = deep_ep.Buffer(group=pg, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=num_qps_per_rank) - return DeepEPHTPrepareAndFinalize(buffer=buffer, - num_dispatchers=pgi.world_size, - dp_size=dp_size, - rank_expert_offset=pgi.rank * - ht_args.num_local_experts) - - -def make_deepep_ll_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - deepep_ll_args: DeepEPLLArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + buffer = deep_ep.Buffer( + group=pg, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank, + ) + return DeepEPHTPrepareAndFinalize( + buffer=buffer, + num_dispatchers=pgi.world_size, + dp_size=dp_size, + rank_expert_offset=pgi.rank * ht_args.num_local_experts, + ) + +def make_deepep_ll_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + deepep_ll_args: DeepEPLLArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # low-latency a2a num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, - pgi.world_size, deepep_ll_args.num_experts) + deepep_ll_args.max_tokens_per_rank, + deepep_ll_args.hidden_size, + pgi.world_size, + deepep_ll_args.num_experts, + ) - buffer = deep_ep.Buffer(group=pg, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=deepep_ll_args.num_experts // - pgi.world_size) + buffer = deep_ep.Buffer( + group=pg, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size, + ) return DeepEPLLPrepareAndFinalize( buffer=buffer, @@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup, ) -def make_deepep_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - deepep_ht_args: Optional[DeepEPHTArgs], - deepep_ll_args: Optional[DeepEPLLArgs], - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): +def make_deepep_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ht_args: Optional[DeepEPHTArgs], + deepep_ll_args: Optional[DeepEPLLArgs], + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): if deepep_ht_args is not None: assert deepep_ll_args is None - return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, - block_shape) + return make_deepep_ht_a2a( + pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape + ) assert deepep_ll_args is not None return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 69317405d48b..c62de471a603 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -7,14 +7,18 @@ import pytest import torch -from tests.kernels.moe.utils import (batched_moe, - make_quantized_test_activations, - make_test_weights, 
naive_batched_moe) +from tests.kernels.moe.utils import ( + batched_moe, + make_quantized_test_activations, + make_test_weights, + naive_batched_moe, +) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel) + invoke_moe_batched_triton_kernel, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform from vllm.triton_utils import tl @@ -68,41 +72,54 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - A = torch.randn( - (config.num_experts, config.max_tokens_per_expert, config.K), + A = ( + torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.in_dtype, + ) + / 10 + ) + B = torch.randn( + (config.num_experts, config.N, config.K), device="cuda", - dtype=config.in_dtype) / 10 - B = torch.randn((config.num_experts, config.N, config.K), - device="cuda", - dtype=config.in_dtype) + dtype=config.in_dtype, + ) C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.out_dtype) + dtype=config.out_dtype, + ) - num_expert_tokens = torch.randint(low=0, - high=config.max_tokens_per_expert, - size=(config.num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts,), + device="cuda", + dtype=torch.int32, + ) return BatchedMMTensors(A, B, C, num_expert_tokens) @pytest.mark.parametrize("num_experts", [8, 16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", - [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 1024]) @pytest.mark.parametrize( - "dtype", - [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) + "dtype", [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16] +) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) -def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, - N: int, dtype: torch.dtype, - block_shape: Optional[list[int]], - per_act_token_quant: bool): +def test_batched_mm( + num_experts: int, + max_tokens_per_expert: int, + K: int, + N: int, + dtype: torch.dtype, + block_shape: Optional[list[int]], + per_act_token_quant: bool, +): current_platform.seed_everything(7) use_fp8_w8a8 = dtype == torch.float8_e4m3fn @@ -120,11 +137,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, act_dtype = dtype quant_dtype = None - num_expert_tokens = torch.randint(low=0, - high=max_tokens_per_expert, - size=(num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=max_tokens_per_expert, + size=(num_experts,), + device="cuda", + dtype=torch.int32, + ) A, A_q, A_scale = make_quantized_test_activations( num_experts, @@ -154,7 +173,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, compute_tl_dtype = { torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, - torch.float32: tl.float32 + torch.float32: tl.float32, }[test_output.dtype] assert A_q.dtype == B_q.dtype @@ -176,7 +195,7 @@ def 
test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, config={ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 + "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32, }, per_act_token_quant=per_act_token_quant, block_shape=block_shape, @@ -189,11 +208,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, num_expert_tokens, ) - q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, - num_expert_tokens, - A_scale, B_scale, - block_shape, - per_act_token_quant) + q_ref_output = native_batched_masked_quant_matmul( + A_q, + B_q, + q_ref_output, + num_expert_tokens, + A_scale, + B_scale, + block_shape, + per_act_token_quant, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -311,12 +335,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(batched_output, - baseline_output, - atol=3e-2, - rtol=2e-2) + torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) - torch.testing.assert_close(triton_output, - batched_output, - atol=2e-2, - rtol=2e-2) + torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 7dc6282326b6..9143103c1839 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -5,15 +5,21 @@ import torch from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) + _valid_deep_gemm_shape, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used @@ -24,8 +30,7 @@ from deep_gemm import get_m_alignment_for_contiguous_layout if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -97,8 +102,7 @@ SEEDS = [0] -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, - block_shape): +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape topk = topk_ids.size(1) @@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, 
block_k) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) # Skip all tests if CUDA is not available @@ -149,8 +147,9 @@ def setup_cuda(): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, - monkeypatch): +def test_w8a8_block_fp8_fused_moe( + M, N, K, E, topk, block_size, dtype, seed, monkeypatch +): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") @@ -161,20 +160,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) - - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=False, - block_shape=block_size) + _, w1, w1_s, _, w2, w2_s = make_test_weights( + E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size, + ) + + m_fused_moe = modular_triton_fused_moe( + use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_act_token_quant=False, + block_shape=block_size, + ) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) @@ -226,8 +229,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, - monkeypatch): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") @@ -246,49 +248,53 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + _, w1, w1_s, _, w2, w2_s = make_test_weights( + E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size, + ) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. use_compile = False - use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = ( + chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike() + ) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids, block_size) + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size + ) if use_compile: - deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, - backend="inductor", - fullgraph=True) + deep_gemm_moe_fp8_fn = torch.compile( + deep_gemm_moe_fp8, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(topk_weights, 0) torch._dynamo.mark_dynamic(topk_ids, 0) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) if use_cudagraph: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 8e680c722935..078d110b1925 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -5,16 +5,17 @@ import torch from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quant_utils import (native_per_token_group_quant_int8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_int8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_int8( - act_out, block_k) + act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k) act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -118,13 +113,9 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - 
torch.int8, - per_act_token_quant=False, - block_shape=block_size) + _, w1, w1_s, _, w2, w2_s = make_test_weights( + E, N, K, dtype, torch.int8, per_act_token_quant=False, block_shape=block_size + ) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): @@ -140,8 +131,9 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): w2_scale=w2_s, block_shape=block_size, ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + ref_out = torch_w8a8_block_int8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) # Check results torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065) diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 0872836b6064..cf648dc36d61 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -15,7 +15,6 @@ @dataclasses.dataclass class TestTensors: - topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] = None @@ -25,32 +24,31 @@ def to_device(self, device: str): self.expert_map = self.expert_map.to(device=device) @staticmethod - def make(num_tokens: int, num_topk: int, num_experts: int, device: str, - topk_ids_dtype: torch.dtype) -> "TestTensors": - + def make( + num_tokens: int, + num_topk: int, + num_experts: int, + device: str, + topk_ids_dtype: torch.dtype, + ) -> "TestTensors": # make topk ids - topk_ids = torch.empty((num_tokens, num_topk), - device=device, - dtype=torch.int64) + topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64) for x in range(num_tokens): topk_ids[x] = torch.randperm(num_experts)[:num_topk] topk_ids = topk_ids.to(dtype=torch.int64) return TestTensors(topk_ids=topk_ids) - def with_ep_rank(self, ep_rank: int, num_global_experts: int, - num_local_experts: int, device: str): + def with_ep_rank( + self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str + ): # make an expert map - expert_map = torch.empty((num_global_experts), - device=device, - dtype=torch.int32) + expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32) expert_map.fill_(-1) s = ep_rank * num_local_experts e = s + num_local_experts - expert_map[s:e] = torch.tensor(list(range(num_local_experts)), - device=device) + expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device) - return TestTensors(topk_ids=self.topk_ids.clone(), - expert_map=expert_map) + return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map) def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): @@ -68,73 +66,81 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): expert_num_tokens[eid] += count -def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - +def do_test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): assert num_topk <= num_experts - tt = TestTensors.make(num_tokens, - num_topk, - num_experts, - topk_ids_dtype=topk_ids_dtype, - device="cpu") + tt = TestTensors.make( + num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu" + ) num_global_experts = num_experts assert num_global_experts % ep_size == 0 num_local_experts = num_global_experts // ep_size for ep_rank in range(ep_size): - tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, - 
num_local_experts, "cpu") + tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu") - ref_expert_num_tokens = torch.zeros((num_local_experts), - device="cpu", - dtype=torch.int32) + ref_expert_num_tokens = torch.zeros( + (num_local_experts), device="cpu", dtype=torch.int32 + ) ref_impl(tt_rank, ref_expert_num_tokens) ref_expert_num_tokens = ref_expert_num_tokens.to("cuda") tt_rank.to_device("cuda") # Test with expert_map triton_expert_num_tokens_w_emap = count_expert_num_tokens( - tt_rank.topk_ids, num_local_experts, tt_rank.expert_map) + tt_rank.topk_ids, num_local_experts, tt_rank.expert_map + ) # Test without expert map topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype) triton_expert_num_tokens_wo_emap = count_expert_num_tokens( - topk_ids, num_local_experts, expert_map=None) + topk_ids, num_local_experts, expert_map=None + ) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_w_emap, - atol=0, - rtol=0) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_wo_emap, - atol=0, - rtol=0) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0 + ) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0 + ) @pytest.mark.parametrize( - "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317]) + "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317] +) @pytest.mark.parametrize("num_topk", [2, 6, 8]) @pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts, - ep_size, topk_ids_dtype) +def test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): + do_test_compute_expert_num_tokens( + num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype + ) @pytest.mark.parametrize("numel", list(range(1, 8192, 11))) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("ep_size", [2]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int, - ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens=numel, - num_topk=1, - num_experts=num_experts, - ep_size=ep_size, - topk_ids_dtype=topk_ids_dtype) +def test_compute_expert_num_tokens_from_numel( + numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype +): + do_test_compute_expert_num_tokens( + num_tokens=numel, + num_topk=1, + num_experts=num_experts, + ep_size=ep_size, + topk_ids_dtype=topk_ids_dtype, + ) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 67984fe7319a..ba9f7edc0e45 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -18,48 +18,50 @@ def cdiv(a, b): return (a + b - 1) // b -def per_token_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def per_token_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else 
x + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * - (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128), - device=x.device, - dtype=x.dtype) + x_padded = torch.zeros( + (cdiv(m, 128) * 128, cdiv(n, 128) * 128), device=x.device, dtype=x.dtype + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - - -@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - (32, 1024, 7168, 4096), - (32, 1024, 2048, 7168), -]) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +@pytest.mark.parametrize( + "num_groups, expected_m_per_group, k, n", + [ + (4, 8192, 7168, 4096), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + (32, 1024, 7168, 4096), + (32, 1024, 2048, 7168), + ], +) @pytest.mark.parametrize("out_dtype", [torch.float16]) @pytest.mark.skipif( (lambda x: x is None or x.to_int() != 100)( - current_platform.get_device_capability()), - reason="Block Scaled Grouped GEMM is only supported on SM100.") + current_platform.get_device_capability() + ), + reason="Block Scaled Grouped GEMM is only supported on SM100.", +) def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int, @@ -70,8 +72,7 @@ def test_cutlass_grouped_gemm( device = "cuda" alignment = 128 group_ms = [ - int(expected_m_per_group * random.uniform(0.7, 1.3)) - for _ in range(num_groups) + int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups) ] m = sum([cdiv(m, alignment) * alignment for m in group_ms]) @@ -88,20 +89,22 @@ def test_cutlass_grouped_gemm( expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, cdiv(n, 128), k // 128), - device=device, - dtype=torch.float)) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float + ), + ) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) for i in range(num_groups): - a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] - a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]] + a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]] + a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]] b = y_fp8[0][i].t() b_scale = y_fp8[1][i].t() baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) - ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline + ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline ops.cutlass_blockwise_scaled_grouped_mm( out, diff 
--git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 5fb49c2da4fe..d61a66723c47 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -10,11 +10,11 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, run_cutlass_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + cutlass_moe_fp8, + run_cutlass_moe_fp8, +) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform NUM_EXPERTS = [40, 64] @@ -35,12 +35,11 @@ (224, 3072, 1536), (32768, 1024, 1024), # These sizes trigger wrong answers. - #(7232, 2048, 5120), - #(40000, 2048, 5120), + # (7232, 2048, 5120), + # (40000, 2048, 5120), ] -vllm_config = VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1)) +vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -56,22 +55,25 @@ class MOETensors: c_strides2: torch.Tensor @staticmethod - def make_moe_tensors(m: int, k: int, n: int, e: int, - dtype: torch.dtype) -> "MOETensors": + def make_moe_tensors( + m: int, k: int, n: int, e: int, dtype: torch.dtype + ) -> "MOETensors": a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - return MOETensors(a=a, - w1=w1, - w2=w2, - ab_strides1=ab_strides1, - c_strides1=c_strides1, - ab_strides2=ab_strides2, - c_strides2=c_strides2) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) + return MOETensors( + a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2, + ) @dataclasses.dataclass @@ -89,9 +91,9 @@ class MOETensors8Bit(MOETensors): w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d @staticmethod - def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, - per_act_token: bool, - per_out_channel: bool) -> "MOETensors8Bit": + def make_moe_tensors_8bit( + m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool + ) -> "MOETensors8Bit": dtype = torch.half q_dtype = torch.float8_e4m3fn @@ -102,24 +104,21 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, k_b_scales = k if per_out_channel else 1 # Get the right scale for tests. 
a_q, a_scale = ops.scaled_fp8_quant( - moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token) + moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token + ) w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w1[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w2[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel + ) # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d a_d = a_q.float().mul(a_scale).to(dtype) @@ -129,31 +128,39 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() - return MOETensors8Bit(a=moe_tensors_fp16.a, - w1=moe_tensors_fp16.w1, - w2=moe_tensors_fp16.w2, - ab_strides1=moe_tensors_fp16.ab_strides1, - c_strides1=moe_tensors_fp16.c_strides1, - ab_strides2=moe_tensors_fp16.ab_strides2, - c_strides2=moe_tensors_fp16.c_strides2, - a_q=a_q, - w1_q=w1_q, - w2_q=w2_q, - a_scale=a_scale, - w1_scale=w1_scale, - w2_scale=w2_scale, - a_d=a_d, - w1_d=w1_d, - w2_d=w2_d) - - -def run_with_expert_maps(num_experts: int, num_local_experts: int, - **cutlass_moe_kwargs): - + return MOETensors8Bit( + a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d, + ) + + +def run_with_expert_maps( + num_experts: int, num_local_experts: int, **cutlass_moe_kwargs +): def slice_experts(): slice_params = [ - "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", - "c_strides2", "w1_scale", "w2_scale" + "w1_q", + "w2_q", + "ab_strides1", + "ab_strides2", + "c_strides1", + "c_strides2", + "w1_scale", + "w2_scale", ] full_tensors = { k: v @@ -167,9 +174,7 @@ def slice_experts(): # make expert map expert_map = [-1] * num_experts expert_map[s:e] = list(range(num_local_experts)) - expert_map = torch.tensor(expert_map, - dtype=torch.int32, - device="cuda") + expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") # update cutlass moe arg with expert_map cutlass_moe_kwargs["expert_map"] = expert_map @@ -186,32 +191,40 @@ def slice_experts(): return out_tensor -def run_8_bit(moe_tensors: MOETensors8Bit, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - per_act_token: bool, - num_local_experts: Optional[int] = None) -> torch.Tensor: - assert not any([ - t is None for t in [ - moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, - moe_tensors.w2_scale, moe_tensors.a_scale +def run_8_bit( + moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + per_act_token: bool, + 
num_local_experts: Optional[int] = None, +) -> torch.Tensor: + assert not any( + [ + t is None + for t in [ + moe_tensors.w1_q, + moe_tensors.w2_q, + moe_tensors.w1_scale, + moe_tensors.w2_scale, + moe_tensors.a_scale, + ] ] - ]) + ) kwargs = { - 'a': moe_tensors.a, - 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] - 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] - 'topk_weights': topk_weights, - 'topk_ids': topk_ids, - 'w1_scale': moe_tensors.w1_scale, - 'w2_scale': moe_tensors.w2_scale, - 'ab_strides1': moe_tensors.ab_strides1, - 'ab_strides2': moe_tensors.ab_strides2, - 'c_strides1': moe_tensors.c_strides1, - 'c_strides2': moe_tensors.c_strides2, - 'per_act_token': per_act_token, - 'a1_scale': None #moe_tensors.a_scale + "a": moe_tensors.a, + "w1_q": moe_tensors.w1_q, # type: ignore[union-attr] + "w2_q": moe_tensors.w2_q, # type: ignore[union-attr] + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "w1_scale": moe_tensors.w1_scale, + "w2_scale": moe_tensors.w2_scale, + "ab_strides1": moe_tensors.ab_strides1, + "ab_strides2": moe_tensors.ab_strides2, + "c_strides1": moe_tensors.c_strides1, + "c_strides2": moe_tensors.c_strides2, + "per_act_token": per_act_token, + "a1_scale": None, # moe_tensors.a_scale } num_experts = moe_tensors.w1.size(0) @@ -223,7 +236,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, return run_with_expert_maps( num_experts, num_local_experts, # type: ignore[arg-type] - **kwargs) + **kwargs, + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -233,8 +247,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_no_graph( m: int, n: int, @@ -249,34 +265,29 @@ def test_cutlass_moe_8_bit_no_graph( current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. - triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids) if ep_size is not None: assert e % ep_size == 0, "Cannot distribute experts evenly" number_local_experts = e // ep_size else: number_local_experts = None - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, - number_local_experts) + cutlass_output = run_8_bit( + mt, topk_weights, topk_ids, per_act_token, number_local_experts + ) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. 
- torch.testing.assert_close(triton_output, - cutlass_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close( + triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2 + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -286,8 +297,10 @@ def test_cutlass_moe_8_bit_no_graph( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_cuda_graph( m: int, n: int, @@ -303,34 +316,25 @@ def test_cutlass_moe_8_bit_cuda_graph( with set_current_vllm_config(vllm_config): dtype = torch.half - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. - triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, - per_act_token) + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) + torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2) @pytest.mark.parametrize("m", [64]) @@ -343,8 +347,10 @@ def test_cutlass_moe_8_bit_cuda_graph( @pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP( m: int, n: int, @@ -356,8 +362,9 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, monkeypatch, ): - test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + ) LARGE_MNK_FACTORS = [ @@ -374,8 +381,10 @@ def test_cutlass_moe_8_bit_EP( @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP_large( m: int, n: int, @@ -387,8 +396,9 @@ def test_cutlass_moe_8_bit_EP_large( ep_size: int, monkeypatch, ): - test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, 
per_act_token, per_out_channel, monkeypatch, ep_size + ) @pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)]) @@ -398,8 +408,10 @@ def test_cutlass_moe_8_bit_EP_large( @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_run_cutlass_moe_fp8( m: int, n: int, @@ -412,14 +424,12 @@ def test_run_cutlass_moe_fp8( ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_channel) + mt = MOETensors8Bit.make_moe_tensors_8bit( + m, k, n, e, per_act_token, per_out_channel + ) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # we want to make sure there is at least one token that's generated in # this expert shard and at least one token that's NOT generated in this # expert shard @@ -430,12 +440,12 @@ def test_run_cutlass_moe_fp8( workspace2_shape = (m * topk, n) output_shape = (m * topk, k) - workspace13 = torch.empty(prod(workspace13_shape), - device="cuda", - dtype=mt.a.dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device="cuda", - dtype=mt.a.dtype) + workspace13 = torch.empty( + prod(workspace13_shape), device="cuda", dtype=mt.a.dtype + ) + workspace2 = torch.empty( + prod(workspace2_shape), device="cuda", dtype=mt.a.dtype + ) num_local_experts = e // ep_size start, end = 0, num_local_experts @@ -443,36 +453,54 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) - a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, - torch.float8_e4m3fn, - per_act_token) + a1q, a1q_scale = moe_kernel_quantize_input( + mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token + ) global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0) func = lambda output: run_cutlass_moe_fp8( - output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, - global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, - workspace13, workspace2, None, mt.a.dtype, per_act_token, - per_out_channel, False) + output, + a1q, + mt.w1_q, + mt.w2_q, + topk_ids, + activation, + global_num_experts, + expert_map, + mt.w1_scale, + mt.w2_scale, + a1q_scale, + None, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + workspace13, + workspace2, + None, + mt.a.dtype, + per_act_token, + per_out_channel, + False, + ) workspace13.random_() - 
output_random_workspace = torch.empty(output_shape, - device="cuda", - dtype=mt.a.dtype) + output_random_workspace = torch.empty( + output_shape, device="cuda", dtype=mt.a.dtype + ) func(output_random_workspace) workspace13.fill_(0) - output_zero_workspace = torch.zeros(output_shape, - device="cuda", - dtype=mt.a.dtype) + output_zero_workspace = torch.zeros( + output_shape, device="cuda", dtype=mt.a.dtype + ) func(output_zero_workspace) - torch.testing.assert_close(output_random_workspace, - output_zero_workspace, - atol=5e-3, - rtol=1e-3) + torch.testing.assert_close( + output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3 + ) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 074771e49a06..945a61d9e6d2 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -16,8 +16,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used @@ -27,18 +26,19 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) + DeepEPHTPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + DeepEPLLPrepareAndFinalize, + ) from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm(): - from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) - from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts) + BatchedDeepGemmExperts, + ) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts requires_deep_ep = pytest.mark.skipif( not has_deep_ep(), @@ -55,9 +55,10 @@ def next_power_of_2(x): import math + if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return 2 ** math.ceil(math.log2(x)) def make_block_quant_fp8_weights( @@ -70,7 +71,8 @@ def make_block_quant_fp8_weights( Return weights w1q, w2q, w1_scale, w2_scale """ w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( - e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) + e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size + ) return w1q, w2q, w1_scale, w2_scale @@ -98,15 +100,15 @@ class TestTensors: @staticmethod def make(config: TestConfig, rank) -> "TestTensors": - dtype = torch.bfloat16 topk, m, k = (config.topk, config.m, config.k) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - rank_tokens = torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + rank_tokens = ( + torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + ) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_token_scales = None @@ -114,24 +116,31 @@ def make(config: TestConfig, rank) -> "TestTensors": low=0, high=config.num_experts, size=(m, topk), - device=torch.cuda.current_device()).to(dtype=torch.int64) - - topk_weights = torch.randn(topk_ids.shape, - dtype=torch.float32, - device=torch.cuda.current_device()) + 
device=torch.cuda.current_device(), + ).to(dtype=torch.int64) - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk_ids, - topk_weights=topk_weights, - config=config) + topk_weights = torch.randn( + topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device() + ) + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk_ids, + topk_weights=topk_weights, + config=config, + ) -def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - max_tokens_per_rank: int, dp_size: int, - hidden_size: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: +def make_ll_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + max_tokens_per_rank: int, + dp_size: int, + hidden_size: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, +) -> FusedMoEModularKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -144,25 +153,30 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank=max_tokens_per_rank, hidden_size=hidden_size, num_experts=test_config.num_experts, - use_fp8_dispatch=test_config.use_fp8_dispatch), + use_fp8_dispatch=test_config.use_fp8_dispatch, + ), q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, block_shape=test_config.block_size, - per_act_token_quant=test_config.per_act_token_quant) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + per_act_token_quant=test_config.per_act_token_quant, + ) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: - +def make_ht_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, +) -> FusedMoEModularKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -173,62 +187,68 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts), deepep_ll_args=None, q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) fused_experts = DeepGemmExperts() - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, - test_tensors: TestTensors) -> FusedMoEModularKernel: - +def make_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + test_tensors: TestTensors, +) -> FusedMoEModularKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config mk: FusedMoEModularKernel # Make modular kernel if test_config.low_latency: - max_tokens_per_rank = max( - 64, next_power_of_2(test_tensors.rank_tokens.size(0))) + max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0))) hidden_size = test_tensors.rank_tokens.size(-1) - mk = make_ll_modular_kernel(pg=pg, - pgi=pgi, - 
max_tokens_per_rank=max_tokens_per_rank, - dp_size=dp_size, - hidden_size=hidden_size, - q_dtype=q_dtype, - test_config=test_config) + mk = make_ll_modular_kernel( + pg=pg, + pgi=pgi, + max_tokens_per_rank=max_tokens_per_rank, + dp_size=dp_size, + hidden_size=hidden_size, + q_dtype=q_dtype, + test_config=test_config, + ) else: - mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, - q_dtype, test_config) + mk = make_ht_modular_kernel( + pg, pgi, dp_size, num_local_experts, q_dtype, test_config + ) return mk -def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, test_tensors: TestTensors, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor]) -> torch.Tensor: - +def deepep_deepgemm_moe_impl( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + test_tensors: TestTensors, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], +) -> torch.Tensor: test_config = test_tensors.config num_experts = test_config.num_experts num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( @@ -236,36 +256,44 @@ def build_expert_map(): pgi=pgi, dp_size=dp_size, num_local_experts=num_local_experts, - test_tensors=test_tensors) + test_tensors=test_tensors, + ) # Low-Latency kernels can't dispatch scales. 
- a1_scale = (None - if test_config.low_latency else test_tensors.rank_token_scales) - - out = mk.forward(hidden_states=test_tensors.rank_tokens, - w1=w1, - w2=w2, - topk_weights=test_tensors.topk_weights, - topk_ids=test_tensors.topk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=a1_scale, - a2_scale=None, - apply_router_weight_on_input=False) - return out + a1_scale = None if test_config.low_latency else test_tensors.rank_token_scales + out = mk.forward( + hidden_states=test_tensors.rank_tokens, + w1=w1, + w2=w2, + topk_weights=test_tensors.topk_weights, + topk_ids=test_tensors.topk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=None, + w2_zp=None, + a1_scale=a1_scale, + a2_scale=None, + apply_router_weight_on_input=False, + ) + return out -def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, - topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a1_scale: torch.Tensor, block_shape: list[int]): +def triton_impl( + a: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: list[int], +): return fused_experts( hidden_states=a, w1=w1, @@ -280,7 +308,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, block_shape=block_shape, # Make sure this is set to False so we # dont end up comparing the same implementation. - allow_deep_gemm=False) + allow_deep_gemm=False, + ) def _test_deepep_deepgemm_moe( @@ -301,22 +330,21 @@ def _test_deepep_deepgemm_moe( pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, pgi.rank) - block_shape = [ - w1.size(1) // w1_scale.size(1), - w1.size(2) // w1_scale.size(2) - ] + block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] with set_current_vllm_config(VllmConfig()): # Reference - triton_moe = triton_impl(a=test_tensors.rank_tokens, - topk_ids=test_tensors.topk, - topk_weights=test_tensors.topk_weights, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=test_tensors.rank_token_scales, - block_shape=block_shape) + triton_moe = triton_impl( + a=test_tensors.rank_tokens, + topk_ids=test_tensors.topk, + topk_weights=test_tensors.topk_weights, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=test_tensors.rank_token_scales, + block_shape=block_shape, + ) # Slice experts for this rank. num_local_experts = config.num_experts // pgi.world_size @@ -369,10 +397,15 @@ def _test_deepep_deepgemm_moe( @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_used(), - reason="Skipping test for Blackwell DeepGEMM") -def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, - topk: int, world_dp_size: tuple[int, int]): +@pytest.mark.skipif( + is_blackwell_deep_gemm_used(), reason="Skipping test for Blackwell DeepGEMM" +) +def test_ht_deepep_deepgemm_moe( + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], +): """ Tests for High-Throughput DeepEP + DeepGemm integration. 
""" @@ -388,21 +421,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, block_size = [block_m, block_m] world_size, dp_size = world_dp_size - config = TestConfig(topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts, - per_act_token_quant=False, - block_size=block_size, - low_latency=False, - use_fp8_dispatch=None) + config = TestConfig( + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts, + per_act_token_quant=False, + block_size=block_size, + low_latency=False, + use_fp8_dispatch=None, + ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) MNKs = [ @@ -426,8 +470,9 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_used(), - reason="Skipping test for Blackwell DeepGEMM") +@pytest.mark.skipif( + is_blackwell_deep_gemm_used(), reason="Skipping test for Blackwell DeepGEMM" +) def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, @@ -460,7 +505,16 @@ def test_ll_deepep_deepgemm_moe( ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 43804c410b6c..5f5a17dc6714 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -15,12 +15,11 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.platforms import current_platform from vllm.utils import has_deep_ep @@ -28,9 +27,11 @@ if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) + DeepEPHTPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + DeepEPLLPrepareAndFinalize, + ) from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a @@ -43,7 +44,7 @@ def make_weights( - e, n, k, dtype + e, n, k, dtype ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Return weights w1, w2, w1_scale, w2_scale @@ -62,17 +63,15 @@ def make_weights( k_b_scales = k w1_q = torch.empty_like(w1, dtype=dtype) w2_q = torch.empty_like(w2, dtype=dtype) - w1_scale = torch.empty((e, n_b_scales, 
1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=True) + w1[expert], use_per_token_if_dynamic=True + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=True) + w2[expert], use_per_token_if_dynamic=True + ) return w1_q, w2_q, w1_scale, w2_scale @@ -98,24 +97,25 @@ class TestTensors: def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": # TODO (varun) - check that float16 works ? assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn] - token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn - else config.dtype) - rank_tokens = torch.randn( - (config.m, config.k), device="cuda", dtype=token_dtype) / 10 + token_dtype = ( + torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype + ) + rank_tokens = ( + torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10 + ) rank_token_scales = None - topk = torch.randint(low=0, - high=config.num_experts, - size=(config.m, config.topk), - device="cuda").to(dtype=torch.int64) - topk_weights = torch.randn(topk.shape, - dtype=torch.float32, - device="cuda") - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk, - topk_weights=topk_weights, - config=config) + topk = torch.randint( + low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda" + ).to(dtype=torch.int64) + topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda") + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk, + topk_weights=topk_weights, + config=config, + ) def make_modular_kernel( @@ -130,30 +130,35 @@ def make_modular_kernel( use_fp8_dispatch: bool, per_act_token_quant: bool, ) -> FusedMoEModularKernel: - is_quantized = q_dtype is not None ht_args: Optional[DeepEPHTArgs] = None ll_args: Optional[DeepEPLLArgs] = None if low_latency_mode: - ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK, - hidden_size=hidden_size, - num_experts=num_experts, - use_fp8_dispatch=use_fp8_dispatch) + ll_args = DeepEPLLArgs( + max_tokens_per_rank=MAX_TOKENS_PER_RANK, + hidden_size=hidden_size, + num_experts=num_experts, + use_fp8_dispatch=use_fp8_dispatch, + ) else: assert not use_fp8_dispatch, ( - "FP8 Dispatch is valid only for low-latency kernels") + "FP8 Dispatch is valid only for low-latency kernels" + ) ht_args = DeepEPHTArgs(num_local_experts=num_local_experts) - a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \ - make_deepep_a2a(pg = pg, - pgi = pgi, - dp_size = dp_size, - q_dtype = q_dtype, - block_shape = None, - deepep_ht_args = ht_args, - deepep_ll_args = ll_args) + a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = ( + make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + q_dtype=q_dtype, + block_shape=None, + deepep_ht_args=ht_args, + deepep_ll_args=ll_args, + ) + ) num_dispatchers = pgi.world_size // dp_size @@ -177,8 +182,7 @@ def make_modular_kernel( per_act_token_quant=per_act_token_quant, ) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) 
return mk @@ -196,19 +200,15 @@ def deep_ep_moe_impl( use_fp8_dispatch: bool, per_act_token_quant: bool, ) -> torch.Tensor: - num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) hidden_size = test_tensors.rank_tokens.size(1) is_quantized = w1.dtype == torch.float8_e4m3fn @@ -218,8 +218,17 @@ def build_expert_map(): # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, - num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant) + pg, + pgi, + low_latency_mode, + hidden_size, + dp_size, + num_experts, + num_local_experts, + q_dtype, + use_fp8_dispatch, + per_act_token_quant, + ) out_hidden_states = torch.empty_like(test_tensors.rank_tokens) total_num_tokens = test_tensors.rank_tokens.size(0) @@ -229,35 +238,38 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end] topk_chunk = test_tensors.topk[chunk_start:chunk_end] rank_token_scales_chunk = test_tensors.rank_token_scales - if rank_token_scales_chunk is not None and rank_token_scales_chunk.size( - 0) == total_num_tokens: + if ( + rank_token_scales_chunk is not None + and rank_token_scales_chunk.size(0) == total_num_tokens + ): # per act token - rank_token_scales_chunk = rank_token_scales_chunk[ - chunk_start:chunk_end] - - out = mk.forward(hidden_states=rank_tokens_chunk, - w1=w1, - w2=w2, - topk_weights=topk_weights_chunk, - topk_ids=topk_chunk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=rank_token_scales_chunk, - a2_scale=None, - apply_router_weight_on_input=False) + rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end] + + out = mk.forward( + hidden_states=rank_tokens_chunk, + w1=w1, + w2=w2, + topk_weights=topk_weights_chunk, + topk_ids=topk_chunk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=None, + w2_zp=None, + a1_scale=rank_token_scales_chunk, + a2_scale=None, + apply_router_weight_on_input=False, + ) if not skip_result_store: - out_hidden_states[chunk_start:chunk_end, :].copy_( - out, non_blocking=True) + out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True) - max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK - if low_latency_mode else total_num_tokens) + max_num_tokens_per_dp = ( + MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens + ) for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp): chunk_start = chunk_start_ @@ -266,9 +278,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_start = min(chunk_start, total_num_tokens - 1) chunk_end = min(chunk_end, total_num_tokens) - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= total_num_tokens) + process_chunk( + chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens + ) 
return out_hidden_states @@ -282,9 +294,11 @@ def torch_moe_impl( using_fp8_dispatch: bool, per_act_token_quant: bool, ): - - a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, - test_tensors.topk_weights) + a, topk_ids, topk_weights = ( + test_tensors.rank_tokens, + test_tensors.topk, + test_tensors.topk_weights, + ) if using_fp8_dispatch: # The DeepEP implementation is requested to dispatch using FP8. # For numerical stability for testing, emulate the fp8 dispatch by @@ -292,8 +306,11 @@ def torch_moe_impl( assert not per_act_token_quant a = test_tensors.rank_tokens aq, aq_scale = per_token_group_quant_fp8(a, 128) - a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( - a.shape).to(a.dtype) + a = ( + (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) + .view(a.shape) + .to(a.dtype) + ) is_quantized = w1.dtype == torch.float8_e4m3fn a_dtype = a.dtype @@ -314,8 +331,9 @@ def torch_moe_impl( e_w = topk_weights[i][j] w1_e = w1[e] w2_e = w2[e] - o_i += (SiluAndMul() - (a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w + o_i += ( + SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1) + ) * e_w if is_quantized: out = out.to(dtype=a_dtype) @@ -335,28 +353,36 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): - if not low_latency_mode: assert not use_fp8_dispatch, ( - "FP8 dispatch interface is available only in low-latency mode") + "FP8 dispatch interface is available only in low-latency mode" + ) is_quantized = w1.dtype == torch.float8_e4m3fn w1 = w1.to(device=torch.cuda.current_device()) w2 = w2.to(device=torch.cuda.current_device()) if is_quantized: w1_scale = w1_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) w2_scale = w2_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, low_latency_mode) with set_current_vllm_config(VllmConfig()): # Reference - torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, - w2_scale, use_fp8_dispatch, - per_act_token_quant) + torch_combined = torch_moe_impl( + test_tensors, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) # Splice experts for this rank. 
num_local_experts = config.num_experts // pgi.world_size @@ -426,18 +452,23 @@ def test_deep_ep_moe( current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - per_act_token_quant) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) MNKs = [ @@ -460,16 +491,18 @@ def test_deep_ep_moe( @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @requires_deep_ep -def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], - num_experts: int, topk: int, - world_dp_size: tuple[int, int], - use_fp8_dispatch: bool): - +def test_low_latency_deep_ep_moe( + dtype: torch.dtype, + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], + use_fp8_dispatch: bool, +): low_latency_mode = True m, n, k = mnk - if (low_latency_mode - and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): + if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES: pytest.skip( f"Skipping test as hidden size {k} is not in list of supported " f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}" @@ -477,15 +510,20 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - False) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + False, + ) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index f7578e226917..360d809aed97 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -14,7 +14,8 @@ # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8 @@ -40,8 +41,10 @@ def make_block_quant_fp8_weights( w2 shape: (E, K, N) """ dtype = torch.bfloat16 - fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( - torch.float8_e4m3fn).min + fp8_max, fp8_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) # bf16 reference weights w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 @@ -57,16 +60,8 @@ def make_block_quant_fp8_weights( w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = 
torch.empty(e, - n_tiles_w1, - k_tiles_w1, - device="cuda", - dtype=torch.float32) - w2_s = torch.empty(e, - n_tiles_w2, - k_tiles_w2, - device="cuda", - dtype=torch.float32) + w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32) + w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32) for i in range(e): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) @@ -80,18 +75,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size): Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == Triton baseline within tolerance. """ - tokens_bf16 = torch.randn( - m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) + tokens_bf16 = ( + torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + .clamp_min_(-1) + .clamp_max_(1) + ) _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) # expert weight tensors - w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, - block_size) + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size) - router_logits = torch.randn(m, - num_experts, - device="cuda", - dtype=torch.float32) + router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32) topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) @@ -150,12 +144,12 @@ def run_single_case(m, n, k, topk, num_experts, block_size): @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @requires_deep_gemm def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): - with monkeypatch.context() as m: m.setenv("VLLM_USE_DEEP_GEMM", "1") _fused_moe_mod = importlib.import_module( - "vllm.model_executor.layers.fused_moe.fused_moe") + "vllm.model_executor.layers.fused_moe.fused_moe" + ) call_counter = {"cnt": 0} @@ -165,8 +159,7 @@ def _spy_deep_gemm_moe_fp8(*args, **kwargs): call_counter["cnt"] += 1 return orig_fn(*args, **kwargs) - monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", - _spy_deep_gemm_moe_fp8) + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) m, n, k = mnk @@ -183,6 +176,7 @@ def _spy_deep_gemm_moe_fp8(*args, **kwargs): ) # ensure that the DeepGEMM path was indeed taken. - assert call_counter["cnt"] == 1, \ - f"DeepGEMM path was not executed during the test. " \ + assert call_counter["cnt"] == 1, ( + f"DeepGEMM path was not executed during the test. 
" f"Call counter: {call_counter['cnt']}" + ) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6f2869c3a61d..2eef8fdcf508 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -11,27 +11,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.layer import TritonExperts from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx -from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, - reference_moe_impl, - run_modular_kernel) +from .modular_kernel_tools.common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) from .modular_kernel_tools.mk_objects import ( - MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) -from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, - parallel_launch_with_config) + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, +) +from .modular_kernel_tools.parallel_utils import ( + ProcessGroupInfo, + parallel_launch_with_config, +) # TODO (varun): These requirements are very strict and could be relaxed. -has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) +has_all_packages = has_deep_ep() and has_deep_gemm() and has_pplx() meets_package_requirements = pytest.mark.skipif( not has_all_packages, @@ -50,8 +60,9 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -72,8 +83,7 @@ def rank_worker( rank_tensors = RankTensors.make(cfgx, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) with set_current_vllm_config(vllm_config): ref_out = reference_moe_impl(cfgx, weights, rank_tensors) @@ -88,8 +98,9 @@ def run(config: Config): weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights + ) Ms = [32, 64] @@ -104,14 +115,17 @@ def run(config: Config): def is_nyi_config(config: Config) -> bool: # We know these configs to be legitimate. but still fail. 
- if (config.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - TritonExperts, TritonOrDeepGemmExperts - ]): + if config.fused_experts_type in [ + BatchedTritonExperts, + BatchedTritonOrDeepGemmExperts, + TritonExperts, + TritonOrDeepGemmExperts, + ]: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. - unsupported_quant_config = ((config.is_per_act_token_quant + - config.is_per_out_ch_quant) == 1) + unsupported_quant_config = ( + config.is_per_act_token_quant + config.is_per_out_ch_quant + ) == 1 return unsupported_quant_config # cutlass kernels dont support expert_maps yet. @@ -124,18 +138,23 @@ def is_nyi_config(config: Config) -> bool: @pytest.mark.parametrize("dtype", DTYPEs) @pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", - product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) + "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) +) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) @meets_package_requirements def test_modular_kernel_combinations_multigpu( - k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, - combination: tuple[mk.FusedMoEPrepareAndFinalize, - mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): - + k: int, + n: int, + e: int, + dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[ + mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute + ], + fused_moe_chunk_size: Optional[int], + world_size: int, +): config = Config( Ms=Ms, K=k, @@ -165,17 +184,23 @@ def test_modular_kernel_combinations_multigpu( @pytest.mark.parametrize("dtype", DTYPEs) @pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", - product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) + "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) +) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [1]) @meets_package_requirements def test_modular_kernel_combinations_singlegpu( - k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, - combination: tuple[mk.FusedMoEPrepareAndFinalize, - mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + k: int, + n: int, + e: int, + dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[ + mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute + ], + fused_moe_chunk_size: Optional[int], + world_size: int, +): config = Config( Ms=Ms, K=k, @@ -199,15 +224,17 @@ def test_modular_kernel_combinations_singlegpu( run(config) -if __name__ == '__main__': +if __name__ == "__main__": # Ability to test individual PrepareAndFinalize and FusedExperts combination - from .modular_kernel_tools.cli_args import (make_config, - make_config_arg_parser) - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts 
combination test" + "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " # noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() config = make_config(args) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0f1c78704642..54e5525ab0a9 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_moe.py`. """ + import functools from typing import Callable, Optional, Union @@ -21,17 +22,23 @@ from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) + fused_moe as iterative_moe, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like) + rand_marlin_weight_fp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize, marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + awq_marlin_quantize, + marlin_quantize, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -64,13 +71,15 @@ def run_moe_test( if isinstance(baseline, torch.Tensor): baseline_output = baseline else: - baseline_output = baseline(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + baseline_output = baseline( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) # Pad the weight if moe padding is enabled if padding: @@ -82,34 +91,35 @@ def run_moe_test( torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(score, 0) - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) if use_cudagraph: test_output.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(test_output, - baseline_output, - atol=atol, - rtol=rtol) + torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) return baseline_output @@ -155,11 +165,8 @@ def test_fused_moe( if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] 
= torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -170,13 +177,15 @@ def test_fused_moe( # Setup test functions # - m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=None) + m_fused_moe_fn = modular_triton_fused_moe( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + use_mxfp4_w4a4=False, + per_act_token_quant=False, + block_shape=None, + ) def m_fused_moe( a: torch.Tensor, @@ -188,13 +197,15 @@ def m_fused_moe( expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - return m_fused_moe_fn(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map) + return m_fused_moe_fn( + a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) @@ -218,19 +229,22 @@ def m_fused_moe( # setup code in case we are able to revisit this later. use_compile = False - use_cudagraph = (n >= 1024 and k >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) - runner(baseline_output, - fused_moe_fn, - use_compile=use_compile, - use_cudagraph=use_cudagraph) - runner(baseline_output, - m_fused_moe, - use_compile=use_compile, - use_cudagraph=use_cudagraph) + runner( + baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) + runner( + baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) @pytest.mark.parametrize("m", [1, 32, 222]) @@ -243,9 +257,18 @@ def m_fused_moe( @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) -def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, - ep_size: int, dtype: torch.dtype, group_size: int, - has_zp: bool, weight_bits: int): +def test_fused_moe_wn16( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + group_size: int, + has_zp: bool, + weight_bits: int, +): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -260,35 +283,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_ref = w1.clone() w2_ref = w2.clone() - w1_qweight = torch.empty((e, 2 * n, k // pack_factor), - device="cuda", - dtype=torch.uint8) - w2_qweight = torch.empty((e, k, n // pack_factor), - device="cuda", - dtype=torch.uint8) - w1_scales = torch.empty((e, 2 * n, k // group_size), - device="cuda", - dtype=dtype) - w2_scales = torch.empty((e, k, n // group_size), - device="cuda", - dtype=dtype) - w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), - device="cuda", - dtype=torch.uint8) - w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), - device="cuda", - dtype=torch.uint8) + w1_qweight = torch.empty( + (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + ) + w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * 
n, k // group_size), device="cuda", dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w1_qzeros = torch.empty( + (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + ) + w2_qzeros = torch.empty( + (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + ) for i in range(e * 2): expert_id = i % e if i // e == 0: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w1, w1_ref, w1_qweight, w1_scales, w1_qzeros + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w1, + w1_ref, + w1_qweight, + w1_scales, + w1_qzeros, + ) else: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w2, w2_ref, w2_qweight, w2_scales, w2_qzeros + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w2, + w2_ref, + w2_qweight, + w2_scales, + w2_qzeros, + ) weight, qweight, scales, qzeros = quantize_weights( - w[expert_id].T, quant_type, group_size, has_zp, False) + w[expert_id].T, quant_type, group_size, has_zp, False + ) weight = weight.T qweight = qweight.T.contiguous().to(torch.uint8) scales = scales.T @@ -307,11 +335,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1_ref = w1_ref[e_ids] w2_ref = w2_ref[e_ids] @@ -325,45 +350,45 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, e_map = None with set_current_vllm_config(vllm_config): - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, - w1_ref, - w2_ref, - score, - topk, - expert_map=e_map) + triton_output = fused_moe( + a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size], + ) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) @torch.inference_mode() -def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, - monkeypatch): +def test_mixtral_moe( + dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch +): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" # clear the cache before every test from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) 
+ is_rocm_aiter_moe_enabled, + ) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -371,17 +396,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") - monkeypatch.setenv('RANK', "0") - monkeypatch.setenv('LOCAL_RANK', "0") - monkeypatch.setenv('WORLD_SIZE', "1") - monkeypatch.setenv('MASTER_ADDR', 'localhost') - monkeypatch.setenv('MASTER_PORT', '12345') + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "12345") init_distributed_environment() # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() - with (set_current_vllm_config(vllm_config), - set_forward_context(None, vllm_config)): + with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config): config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") vllm_moe = MixtralMoE( @@ -397,28 +421,31 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) + weights = ( + hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data, + ) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn( - (1, 64, config.hidden_size)).to(dtype).to("cuda") + hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) # Pad the weight if moe padding is enabled if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) + vllm_moe.experts.w13_weight = Parameter( + F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[ + ..., 0:-128 + ], + requires_grad=False, + ) torch.cuda.empty_cache() - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) + vllm_moe.experts.w2_weight = Parameter( + F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], + requires_grad=False, + ) torch.cuda.empty_cache() # Run forward passes for both MoE blocks @@ -434,19 +461,21 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, if use_rocm_aiter: # The values of rtol and atol are set based on the tests in ROCM AITER package. 
# noqa: E501 # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501 - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=0.01, - atol=100) + torch.testing.assert_close( + hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100 + ) else: - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=mixtral_moe_tol[dtype], - atol=mixtral_moe_tol[dtype]) + torch.testing.assert_close( + hf_states.flatten(0, 1), + vllm_states, + rtol=mixtral_moe_tol[dtype], + atol=mixtral_moe_tol[dtype], + ) def marlin_moe_generate_valid_test_cases(): import itertools + m_list = [1, 123, 666] n_list = [128, 1024] k_list = [256, 2048] @@ -465,16 +494,24 @@ def marlin_moe_generate_valid_test_cases(): ] is_k_full_list = [True, False] - all_combinations = itertools.product(m_list, n_list, k_list, e_list, - topk_list, ep_size_list, dtype_list, - group_size_list, act_order_list, - quant_type_list, is_k_full_list) - - def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, - quant_type, is_k_full): + all_combinations = itertools.product( + m_list, + n_list, + k_list, + e_list, + topk_list, + ep_size_list, + dtype_list, + group_size_list, + act_order_list, + quant_type_list, + is_k_full_list, + ) - if quant_type == scalar_types.float8_e4m3fn and \ - group_size not in [-1, 128]: + def is_invalid( + m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full + ): + if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: return False if quant_type == scalar_types.float4_e2m1f and group_size != 16: return False @@ -500,9 +537,10 @@ def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, @pytest.mark.flaky(reruns=2) -@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," - "act_order, quant_type, is_k_full"), - marlin_moe_generate_valid_test_cases()) +@pytest.mark.parametrize( + ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases(), +) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -552,7 +590,7 @@ def test_fused_marlin_moe( if ep_size > 1: local_e = e // ep_size e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -569,22 +607,23 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref1, qweight1, scales1, global_scale1 = \ - rand_marlin_weight_fp4_like(w1[i], group_size) + w_ref1, qweight1, scales1, global_scale1 = rand_marlin_weight_fp4_like( + w1[i], group_size + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size) + w1[i].transpose(1, 0), quant_type, group_size + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -592,9 +631,9 @@ def 
test_fused_marlin_moe( zeros1_l.append(zeros1) else: test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ - marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -620,22 +659,23 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref2, qweight2, scales2, global_scale2 = \ - rand_marlin_weight_fp4_like(w2[i], group_size) + w_ref2, qweight2, scales2, global_scale2 = rand_marlin_weight_fp4_like( + w2[i], group_size + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size) + w2[i].transpose(1, 0), quant_type, group_size + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -643,9 +683,9 @@ def test_fused_marlin_moe( zeros2_l.append(zeros2) else: test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ - marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -666,12 +706,7 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, - w_ref1, - w_ref2, - score, - topk, - expert_map=e_map) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -693,7 +728,8 @@ def test_fused_marlin_moe( w1_zeros=zeros1, w2_zeros=zeros2, quant_type_id=quant_type.id, - is_k_full=is_k_full) + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -701,34 +737,36 @@ def test_fused_marlin_moe( def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 - topk_ids = torch.randint(0, - num_experts, (3, 4), - dtype=torch.int32, - device='cuda') + topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - - opcheck(torch.ops._moe_C.moe_align_block_size, - (topk_ids, num_experts, block_size, sorted_ids, expert_ids, - num_tokens_post_pad)) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), 
dtype=torch.int32, device=topk_ids.device) + + opcheck( + torch.ops._moe_C.moe_align_block_size, + ( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): input = torch.randn((m, topk, k), device="cuda", dtype=dtype) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index e980422a7b97..e491b2d7898c 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -7,7 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size_triton) + moe_align_block_size_triton, +) @pytest.mark.parametrize( @@ -26,28 +27,32 @@ ], # num_tokens [1, 4, 16, 64], # topk [64, 160, 256, 257, 260, 264], # num_experts - )), + ) + ), ) -def test_moe_align_block_size_compare_implementations(block_size, num_tokens, - topk, num_experts): - topk_ids = torch.stack([ - torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] - for _ in range(num_tokens) - ]) +def test_moe_align_block_size_compare_implementations( + block_size, num_tokens, topk, num_experts +): + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(num_tokens) + ] + ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids_cuda = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids_cuda = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids_cuda.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids_cuda = torch.zeros((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad_cuda = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) + expert_ids_cuda = torch.zeros( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad_cuda = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) sorted_ids_triton = torch.empty_like(sorted_ids_cuda) sorted_ids_triton.fill_(topk_ids.numel()) @@ -76,14 +81,15 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, f"Expert IDs mismatch for block_size={block_size}, " f"num_tokens={num_tokens}, topk={topk}\n" f"CUDA expert_ids: {expert_ids_cuda}\n" - f"Triton expert_ids: {expert_ids_triton}") + f"Triton expert_ids: {expert_ids_triton}" + ) - assert torch.allclose( - num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( - f"Num tokens post pad mismatch for block_size={block_size}, " - f"num_tokens={num_tokens}, topk={topk}\n" - f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" - f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}") + assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( + f"Num tokens post pad mismatch for block_size={block_size}, " + f"num_tokens={num_tokens}, topk={topk}\n" + f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" + f"Triton 
num_tokens_post_pad: {num_tokens_post_pad_triton}" + ) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 7cc83b512c8b..403f018be61f 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -14,7 +14,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_permute_unpermute_supported, moe_unpermute) + moe_permute, + moe_permute_unpermute_supported, + moe_unpermute, +) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64] @@ -23,30 +26,32 @@ current_platform.seed_everything(0) -def torch_permute(hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: +def torch_permute( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1, +) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: - is_local_expert = (expert_map[topk_ids] != -1) - not_local_expert = (expert_map[topk_ids] == -1) - topk_ids = is_local_expert * ( - topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + is_local_expert = expert_map[topk_ids] != -1 + not_local_expert = expert_map[topk_ids] == -1 + topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * ( + topk_ids + n_expert + ) - sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), - stable=True) + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] - expert_first_token_offset = torch.zeros(n_local_expert + 1, - dtype=torch.int64, - device="cuda") + expert_first_token_offset = torch.zeros( + n_local_expert + 1, dtype=torch.int64, device="cuda" + ) idx = 0 for i in range(0, n_local_expert): cnt = 0 @@ -58,101 +63,116 @@ def torch_permute(hidden_states: torch.Tensor, _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) valid_row_idx = [] if align_block_size is None: - - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % - n_token, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % n_token, ...] 
permuted_row_size = permuted_hidden_states.shape[0] - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] m_indices[first_token_offset:last_token_offset] = i - 1 src_row_id2dst_row_id_map = torch.arange( - 0, n_token * topk, device="cuda", - dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + 0, n_token * topk, device="cuda", dtype=torch.int32 + )[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] return [ - permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices, valid_row_idx + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, + m_indices, + valid_row_idx, ] else: - permuted_row_size = (topk * n_token + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size - permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), - device="cuda", - dtype=hidden_states.dtype) - align_src_row_id2dst_row_id = torch.empty(n_token * topk, - device="cuda", - dtype=torch.int32) - align_expert_first_token_offset = torch.zeros_like( - expert_first_token_offset) - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + permuted_row_size = ( + (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) + // align_block_size + * align_block_size + ) + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype + ) + align_src_row_id2dst_row_id = torch.empty( + n_token * topk, device="cuda", dtype=torch.int32 + ) + align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) # get align_permuted_hidden_states, # valid row_idx and align_expert_first_token_offset for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] n_token_in_expert = last_token_offset - first_token_offset - align_expert_first_token_offset[ - i] = align_expert_first_token_offset[ - i - 1] + (n_token_in_expert + align_block_size - - 1) // align_block_size * align_block_size + align_expert_first_token_offset[i] = ( + align_expert_first_token_offset[i - 1] + + (n_token_in_expert + align_block_size - 1) + // align_block_size + * align_block_size + ) align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] - dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + - n_token_in_expert] % n_token + dst_row_id2src_row_id_in_expert = ( + dst_row_id2src_row_id_map[ + first_token_offset : first_token_offset + n_token_in_expert + ] + % n_token + ) # store token in current expert with align_first_token_offset - permuted_hidden_states[align_first_token_offset:\ - align_first_token_offset+n_token_in_expert,\ - ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert, ...] 
+ permuted_hidden_states[ + align_first_token_offset : align_first_token_offset + n_token_in_expert, + ..., + ] = hidden_states[dst_row_id2src_row_id_in_expert, ...] # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ - i for i in range(align_first_token_offset, - align_first_token_offset + n_token_in_expert) + i + for i in range( + align_first_token_offset, + align_first_token_offset + n_token_in_expert, + ) ] # get align_src_row_id2dst_row_id for i in range(n_token * topk): eid = sorted_topk_ids[i] - if (eid >= n_local_expert): + if eid >= n_local_expert: # check token not in local expert - align_src_row_id2dst_row_id[ - i] = align_expert_first_token_offset[-1] + align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1] continue first_token_offset = expert_first_token_offset[eid] align_first_token_offset = align_expert_first_token_offset[eid] token_offset = i - first_token_offset - align_src_row_id2dst_row_id[ - i] = align_first_token_offset + token_offset - align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ - src2dst_idx].reshape((n_token, topk)) + align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape( + (n_token, topk) + ) return [ - permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, m_indices, valid_row_idx + permuted_hidden_states, + align_expert_first_token_offset, + align_src_row_id2dst_row_id, + m_indices, + valid_row_idx, ] -def torch_unpermute(permuted_hidden_states: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - valid_row_idx: torch.Tensor, topk: int, - n_expert: int) -> torch.Tensor: +def torch_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, + topk: int, + n_expert: int, +) -> torch.Tensor: # ignore invalid row - mask = torch.zeros(permuted_hidden_states.shape[0], - dtype=bool, - device="cuda") + mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 - idx = src_row_id2dst_row_id_map.flatten()[ - token_expert_indices.flatten()].reshape(token_expert_indices.shape) + idx = src_row_id2dst_row_id_map.flatten()[token_expert_indices.flatten()].reshape( + token_expert_indices.shape + ) output = permuted_hidden_states[idx, ...] 
* topk_weights[..., None] output = output.sum(dim=1).to(permuted_hidden_states.dtype) return output @@ -165,25 +185,31 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("align_block_size", [None, 128]) -def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, - n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: Optional[int]): +def test_moe_permute_unpermute( + n_token: int, + n_hidden: int, + topk: int, + n_expert: int, + ep_size: int, + dtype: torch.dtype, + align_block_size: Optional[int], +): if not moe_permute_unpermute_supported(): pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None n_local_expert = n_expert - if (ep_size != 1): - n_local_expert, expert_map = determine_expert_map( - ep_size, ep_rank, n_expert) + if ep_size != 1: + n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert) expert_map = expert_map.cuda() start_expert = n_local_expert * ep_rank current_platform.seed_everything(0) hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, False) + hidden_states, gating_output, topk, False + ) gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( hidden_states, topk_ids, @@ -194,12 +220,21 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, start_expert, expert_map=expert_map, align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) + fill_invalid_expert=fill_invalid_expert, + ) result0, result1, result2, result3 = moe_permute( - hidden_states, topk_weights, topk_ids, token_expert_indices, topk, - n_expert, n_local_expert, expert_map, align_block_size, - fill_invalid_expert) + hidden_states, + topk_weights, + topk_ids, + token_expert_indices, + topk, + n_expert, + n_local_expert, + expert_map, + align_block_size, + fill_invalid_expert, + ) # check expert_first_token_offset torch.testing.assert_close(gold1, result1, atol=0, rtol=0) @@ -208,19 +243,33 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, # check mindice torch.testing.assert_close(gold3, result3, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold0[valid_row_idx], - result0[valid_row_idx], - atol=0, - rtol=0) + torch.testing.assert_close( + gold0[valid_row_idx], result0[valid_row_idx], atol=0, rtol=0 + ) # add a random tensor to simulate group gemm result0 = 0.5 * result0 + torch.randn_like(result0) - result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, - topk, n_expert, n_local_expert) - gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, result2, valid_row_idx, topk, - n_local_expert) + result4 = moe_unpermute( + result0, + topk_weights, + topk_ids, + result2, + result1, + topk, + n_expert, + n_local_expert, + ) + gold4 = torch_unpermute( + result0, + topk_weights, + topk_ids, + token_expert_indices, + result2, + valid_row_idx, + topk, + n_local_expert, + ) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 824b072a9f93..5cf8e1bd6e94 
100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -9,9 +9,9 @@ import torch from packaging import version -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( - "quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") @dataclass @@ -20,22 +20,25 @@ class ModelCase: tp: int -@pytest.mark.parametrize('model_case', [ - ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), - ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), - ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1) -]) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize( + "model_case", + [ + ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), + ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), + ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1), + ], +) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): if torch.cuda.device_count() < model_case.tp: - pytest.skip(f"This test requires >={model_case.tp} gpus, got only " - f"{torch.cuda.device_count()}") - - with vllm_runner(model_case.model_id, - tensor_parallel_size=model_case.tp, - load_format="dummy") as llm: - + pytest.skip( + f"This test requires >={model_case.tp} gpus, got only " + f"{torch.cuda.device_count()}" + ) + + with vllm_runner( + model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy" + ) as llm: # TODO: llm.apply_model(check_model) currently relies on V0 internals. # Re-enable this later. 
# def check_model(model): @@ -52,6 +55,5 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": # llm.apply_model(check_model) - output = llm.generate_greedy("Today I am in the French Alps and", - max_tokens=20) - assert output \ No newline at end of file + output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) + assert output diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 3f5412e75821..f51e081984a5 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -3,9 +3,11 @@ import pytest import torch -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config @@ -14,8 +16,9 @@ from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip("Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + "Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True + ) MNK_FACTORS = [ (2, 1024, 1024), @@ -36,36 +39,34 @@ @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() -def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype): +def test_cutlass_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): current_platform.seed_everything(7) with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 round_up = lambda x, y: (x + y - 1) // y * y sf_w1_2n = round_up(2 * n, 128) sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), - device="cuda", - dtype=torch.float8_e4m3fn) + w1_blockscale = torch.empty( + (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 sf_w2_k = round_up(k, 128) sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), - device="cuda", - dtype=torch.float8_e4m3fn) + w2_blockscale = torch.empty( + (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + ) - w1_q = torch.empty((e, 2 * n, k // 2), - device="cuda", - dtype=torch.uint8) + w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) for expert in range(e): w1_amax = torch.abs(w1).max().to(torch.float32) @@ -74,19 +75,18 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax w1_q[expert], 
w1_blockscale[expert] = ops.scaled_fp4_quant( - w1[expert], w1_gs[expert]) + w1[expert], w1_gs[expert] + ) w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( - w2[expert], w2_gs[expert]) + w2[expert], w2_gs[expert] + ) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) cutlass_output = cutlass_moe_fp4( a=a, @@ -108,40 +108,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ) # Reference check: - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) _, m_k = a_fp4.shape - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=a.dtype, - device=a.device, - block_size=quant_blocksize) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=w1.dtype, - device=w1.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=w2.dtype, - device=w2.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize, + ) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) - torch.testing.assert_close(torch_output, - cutlass_output, - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 77adc89ea9da..2403450d1bdf 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -11,8 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils import cdiv @@ -20,9 +19,13 @@ try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, 
+ ) + has_pplx = True except ImportError: has_pplx = False @@ -46,12 +49,12 @@ def chunk_by_rank(t, r, w): chunk = rank_chunk(num, r, w) rem = num % w if rem == 0 or r < rem: - return t[(r * chunk):(r + 1) * chunk].contiguous() + return t[(r * chunk) : (r + 1) * chunk].contiguous() else: long_chunks = (num // w + 1) * rem short_chunks = (r - rem) * chunk start = long_chunks + short_chunks - return t[start:start + chunk].contiguous() + return t[start : start + chunk].contiguous() def pplx_cutlass_moe( @@ -71,7 +74,9 @@ def pplx_cutlass_moe( group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) + assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape @@ -122,35 +127,34 @@ def pplx_cutlass_moe( ata, max_num_tokens=max_num_tokens, num_local_experts=num_local_experts, - num_dispatchers=num_dispatchers) - - ab_strides1 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_local_experts, ), - intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_local_experts, ), - 2 * intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) - - experts = CutlassExpertsFp8(num_local_experts, - out_dtype, - per_act_token, - per_out_ch, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, - num_dispatchers=num_dispatchers, - use_batched_format=True) + num_dispatchers=num_dispatchers, + ) + + ab_strides1 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) + ab_strides2 = torch.full( + (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides1 = torch.full( + (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides2 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) + + experts = CutlassExpertsFp8( + num_local_experts, + out_dtype, + per_act_token, + per_out_ch, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + num_dispatchers=num_dispatchers, + use_batched_format=True, + ) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -158,10 +162,10 @@ def pplx_cutlass_moe( ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - chunk_topk_weight = chunk_by_rank(topk_weights, rank, - world_size).to(device) - chunk_topk_ids = chunk_by_rank(topk_ids, rank, - world_size).to(torch.uint32).to(device) + chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device) + chunk_topk_ids = ( + chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device) + ) out = fused_cutlass_experts( a_chunk, @@ -170,11 +174,13 @@ def pplx_cutlass_moe( chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts, - expert_map=None, #TODO + expert_map=None, # TODO w1_scale=chunk_by_rank(w1_scale, rank, world_size), w2_scale=chunk_by_rank(w2_scale, rank, world_size), a1_scale=chunk_by_rank(a1_scale, rank, world_size) - if per_act_token else a1_scale[rank]) + if per_act_token + else a1_scale[rank], + ) torch.cuda.synchronize() @@ -209,35 +215,48 @@ def _pplx_moe( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, 
pgi.rank, pgi.world_size) else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_experts(a_full, w1_full, w2_full, - topk_weights, topk_ids) - pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, - w2_scale, topk_weights, topk_ids, - a1_scale, out_dtype, per_act_token, - per_out_ch, group_name) - - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) + torch_output = torch_experts( + a_full, w1_full, w2_full, topk_weights, topk_ids + ) + pplx_output = pplx_cutlass_moe( + pgi, + dp_size, + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a1_scale, + out_dtype, + per_act_token, + per_out_ch, + group_name, + ) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pplx_output.device + ) # Uncomment if more debugging is needed # print("PPLX OUT:", pplx_output) # print("TORCH OUT:", torch_output) - torch.testing.assert_close(pplx_output, - torch_output, - atol=0.05, - rtol=0) + torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) finally: if use_internode: nvshmem_finalize() @@ -250,12 +269,14 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) @requires_pplx def test_cutlass_moe_pplx( m: int, @@ -271,7 +292,6 @@ def test_cutlass_moe_pplx( current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - dtype = torch.half a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0 @@ -281,22 +301,18 @@ def test_cutlass_moe_pplx( n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn) w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) + w1[expert], use_per_token_if_dynamic=per_out_ch + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) + w2[expert], use_per_token_if_dynamic=per_out_ch + ) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -305,19 +321,35 @@ def test_cutlass_moe_pplx( w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, 
topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) world_size, dp_size = world_dp_size - a_scale1 = torch.randn( - (m if per_act_token else 1, 1), device="cuda", - dtype=torch.float32) / 10.0 + a_scale1 = ( + torch.randn( + (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32 + ) + / 10.0 + ) if not per_act_token: a_scale1 = a_scale1.repeat(world_size, 1) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, - w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, - dtype, a, w1_d, w2_d, per_act_token, per_out_ch, - use_internode) + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a_scale1, + dtype, + a, + w1_d, + w2_d, + per_act_token, + per_out_ch, + use_internode, + ) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f7a661b4bc7b..5237703d4389 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ + import itertools import textwrap import traceback @@ -14,9 +15,13 @@ try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, + ) + has_pplx = True except ImportError: has_pplx = False @@ -27,13 +32,12 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.platforms import current_platform from vllm.utils import round_up @@ -46,7 +50,7 @@ PPLX_COMBOS = [ # TODO: figure out why this fails, seems to be test problem - #(1, 128, 128), + # (1, 128, 128), (2, 128, 512), (3, 1024, 2048), (4, 128, 128), @@ -78,17 +82,16 @@ def torch_prepare( num_tokens, hidden_dim = a.shape topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) assert tokens_per_expert.numel() == num_experts if max_num_tokens is None: max_num_tokens = int(tokens_per_expert.max().item()) - b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), - dtype=a.dtype, - device=a.device) + b_a = torch.zeros( + (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device + ) token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) @@ -96,28 +99,29 @@ def torch_prepare( for j in range(topk): expert_id = topk_ids[token, j] idx = token_counts[expert_id] - b_a[expert_id, idx:idx + 1, :] = a[token, :] + b_a[expert_id, idx : idx + 1, :] = a[token, :] token_counts[expert_id] = 
token_counts[expert_id] + 1 return b_a, tokens_per_expert -def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: +def torch_finalize( + b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor +) -> torch.Tensor: num_tokens = topk_ids.shape[0] num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, - dtype=torch.int, - device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx + - 1, :] * topk_weight[token, i] + out[token, :] = ( + out[token, :] + + b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i] + ) expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -136,17 +140,18 @@ def torch_batched_moe( num_tokens, topk = topk_ids.shape _, max_num_tokens, K = b_a.shape assert num_experts == b_a.shape[0] and w2.shape[1] == K - out = torch.zeros((num_experts, max_num_tokens, K), - dtype=b_a.dtype, - device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), - dtype=b_a.dtype, - device=b_a.device) + out = torch.zeros( + (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device + ) + tmp = torch.empty( + (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device + ) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: torch.ops._C.silu_and_mul( - tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1) + ) out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) return torch_finalize(out, topk_weight, topk_ids) @@ -175,20 +180,16 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts(a, w1, w2, topk_weight, - topk_ids) # only for baseline + baseline_output = torch_experts( + a, w1, w2, topk_weight, topk_ids + ) # only for baseline torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = naive_batched_moe( - a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this + a, w1, w2, topk_weight, topk_ids + ) # pick torch_experts or this - torch.testing.assert_close(baseline_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(baseline_output, - batched_output, - atol=2e-2, - rtol=0) + torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) def create_pplx_prepare_finalize( @@ -206,7 +207,9 @@ def create_pplx_prepare_finalize( group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) num_local_experts = rank_chunk(num_experts, 0, world_size) @@ -255,28 +258,31 @@ def rank_chunk(num: int, r: int, w: int) -> int: def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] -def maybe_chunk_by_rank(t: 
Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def maybe_chunk_by_rank( + t: Optional[torch.Tensor], r: int, w: int +) -> Optional[torch.Tensor]: if t is not None: return chunk_by_rank(t, r, w) else: return t -def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def chunk_scales_by_rank( + t: Optional[torch.Tensor], r: int, w: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] else: return t -def chunk_scales(t: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + t: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: return t[start:end] else: @@ -339,8 +345,7 @@ def pplx_prepare_finalize( device=device, ) - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -364,8 +369,7 @@ def pplx_prepare_finalize( ), ) - b_a = dummy_work( - dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -399,15 +403,17 @@ def _pplx_prepare_finalize( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -415,22 +421,28 @@ def _pplx_prepare_finalize( a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) - torch_output = (a_rep.view(m, topk, k) * - topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum( - dim=1) - - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, - topk_ids, num_experts, quant_dtype, - block_shape, per_act_token_quant, - group_name) + torch_output = ( + a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype) + ).sum(dim=1) + + pplx_output = pplx_prepare_finalize( + pgi, + dp_size, + a, + topk_weight, + topk_ids, + num_experts, + quant_dtype, + block_shape, + per_act_token_quant, + group_name, + ) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pgi.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pgi.device + ) - torch.testing.assert_close(pplx_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) finally: if use_internode: nvshmem_finalize() @@ -479,9 +491,19 @@ def test_pplx_prepare_finalize_slow( a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, - topk, e, quant_dtype, block_shape, per_act_token_quant, - use_internode) + parallel_launch( + world_size, + 
_pplx_prepare_finalize, + dp_size, + a, + score, + topk, + e, + quant_dtype, + block_shape, + per_act_token_quant, + use_internode, + ) def pplx_moe( @@ -504,7 +526,6 @@ def pplx_moe( use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: - num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] topk = topk_ids.shape[1] @@ -557,41 +578,45 @@ def pplx_moe( # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. if use_compile: - _fused_experts = torch.compile(fused_experts, - backend='inductor', - fullgraph=True) + _fused_experts = torch.compile( + fused_experts, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a_chunk, 0) torch._dynamo.mark_dynamic(chunk_topk_weight, 0) torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, + global_num_experts=num_experts, + ) if use_cudagraphs: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, + global_num_experts=num_experts, + ) torch.cuda.synchronize() graph.replay() @@ -621,15 +646,17 @@ def _pplx_moe( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name m, k = a.shape @@ -647,8 +674,7 @@ def _pplx_moe( w1_s = w1_s.to(device) if w1_s is not None else None w2_s = w2_s.to(device) if w2_s is not None else None - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -708,17 +734,14 @@ def _pplx_moe( ) chunked_batch_output = chunk_by_rank( - batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) + batched_output, pgi.rank, pgi.world_size + ).to(pplx_output.device) - torch.testing.assert_close(batched_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) - torch.testing.assert_close(pplx_output, - chunked_batch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close( + 
pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2 + ) finally: if use_internode: nvshmem_finalize() @@ -773,14 +796,32 @@ def test_pplx_moe_slow( per_act_token_quant=per_act_token_quant, ) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, - w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, - use_internode) - + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1, + w2, + score, + topk, + e, + w1_s, + w2_s, + quant_dtype, + per_act_token_quant, + block_shape, + use_internode, + ) -def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - make_weights: bool, test_fn: Callable): +def _pplx_test_loop( + pgi: ProcessGroupInfo, + dp_size: int, + use_internode: bool, + make_weights: bool, + test_fn: Callable, +): def format_result(msg, ex=None): if ex is not None: x = str(ex) @@ -795,8 +836,9 @@ def format_result(msg, ex=None): print(f"PASSED {msg}") current_platform.seed_everything(7) - combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - [False, True], [None, [128, 128]]) + combos = itertools.product( + PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]] + ) exceptions = [] count = 0 for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: @@ -810,15 +852,14 @@ def format_result(msg, ex=None): use_fp8_w8a8 = False quant_dtype = None - test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " - f"dtype={dtype}, per_act_token={per_act_token_quant}, " - f"block_shape={block_shape}") + test_desc = ( + f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}" + ) - if not use_fp8_w8a8 and (per_act_token_quant - or block_shape is not None): - print( - f"{test_desc} - Skip quantization test for non-quantized type." - ) + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): + print(f"{test_desc} - Skip quantization test for non-quantized type.") continue if per_act_token_quant and block_shape is not None: @@ -865,10 +906,10 @@ def format_result(msg, ex=None): if len(exceptions) > 0: raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " - f"rank={pgi.rank}.") + f"rank={pgi.rank}." 
+ ) else: - print(f"{count} of {count} tests passed in child process, " - f"rank={pgi.rank}.") + print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.") @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @@ -880,8 +921,14 @@ def test_pplx_prepare_finalize( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, _pplx_prepare_finalize) + parallel_launch( + world_size * dp_size, + _pplx_test_loop, + dp_size, + use_internode, + False, + _pplx_prepare_finalize, + ) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @@ -893,5 +940,6 @@ def test_pplx_moe( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, - _pplx_moe) + parallel_launch( + world_size, _pplx_test_loop, dp_size, use_internode, True, _pplx_moe + ) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py index 1c51c530c193..d4724d749fc9 100644 --- a/tests/kernels/moe/test_rocm_aiter_topk.py +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -24,13 +24,14 @@ pytestmark = pytest.mark.skipif( not (current_platform.is_rocm() and aiter_available), - reason="AITER ops are only available on ROCm with aiter package installed") + reason="AITER ops are only available on ROCm with aiter package installed", +) def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk') + assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) @@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): def test_rocm_aiter_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk') + assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_grouped_topk) @@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): renormalize = True scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") - e_score_correction_bias = torch.randn((expert, ), - dtype=torch.bfloat16, - device="cuda") + gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") + e_score_correction_bias = torch.randn( + (expert,), dtype=torch.bfloat16, device="cuda" + ) device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function that uses the op - def biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights, topk_ids): + def biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights, topk_ids + ): return torch.ops.vllm.rocm_aiter_biased_grouped_topk( - gating_output, e_score_correction_bias, topk_weights, topk_ids, - num_expert_group, topk_group, renormalize, scale_factor) + gating_output, + e_score_correction_bias, + topk_weights, + topk_ids, + num_expert_group, + 
topk_group, + renormalize, + scale_factor, + ) # Verify the op's fake implementation torch.library.opcheck( @@ -84,51 +89,49 @@ def biased_grouped_topk_fn(gating_output, e_score_correction_bias, "num_expert_group": num_expert_group, "topk_group": topk_group, "need_renorm": renormalize, - "routed_scaling_factor": scale_factor + "routed_scaling_factor": scale_factor, }, - test_utils=("test_faketensor")) + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(biased_grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) - - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) - - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + compiled_fn = torch.compile( + biased_grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) + + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) + + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights_original, topk_ids_original) - compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled, - topk_ids_compiled) + biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original + ) + compiled_fn( + gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled + ) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) @@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility(): scoring_func = "softmax" scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") + gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function that uses the op def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func): return torch.ops.vllm.rocm_aiter_grouped_topk( - 
gating_output, topk_weights, topk_ids, num_expert_group, - topk_group, renormalize, scoring_func, scale_factor) + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scoring_func, + scale_factor, + ) # Verify the op's fake implementation - torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk, - (gating_output, topk_weights, topk_ids), - kwargs={ - "num_expert_group": num_expert_group, - "topk_group": topk_group, - "need_renorm": renormalize, - "scoring_func": scoring_func, - "routed_scaling_factor": scale_factor - }, - test_utils=("test_faketensor")) + torch.library.opcheck( + torch.ops.vllm.rocm_aiter_grouped_topk, + (gating_output, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "scoring_func": scoring_func, + "routed_scaling_factor": scale_factor, + }, + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) - - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) - - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + compiled_fn = torch.compile( + grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) + + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) + + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original, - scoring_func) - compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, - scoring_func) + grouped_topk_fn( + gating_output, topk_weights_original, topk_ids_original, scoring_func + ) + compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 673a0aa36794..59ceb1e7e374 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,7 +5,8 @@ import 
torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + silu_mul_fp8_quant_deep_gemm, +) from vllm.platforms import current_platform # (E, T, H, group_size, seed) @@ -28,16 +29,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): tokens_per_expert = torch.randint( low=0, high=T, - size=(E, ), + size=(E,), dtype=torch.int32, device="cuda", ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) + y_q, y_s = silu_mul_fp8_quant_deep_gemm( + y, tokens_per_expert, group_size=group_size, eps=1e-10 + ) # Reference implementation fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -54,9 +54,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), - dtype=torch.float32, - device="cuda") + ref_s = torch.empty((T, H // group_size), dtype=torch.float32, device="cuda") ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") for t in range(nt): data = merged[e, t] diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index dfd0f35c8da3..3bc1ac7b36c7 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -14,8 +14,7 @@ from vllm.platforms import current_platform if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -29,14 +28,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -86,17 +84,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = ops.scaled_fp8_quant( - act_out, use_per_token_if_dynamic=True) + act_out, use_per_token_if_dynamic=True + ) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -114,8 +112,10 @@ def setup_cuda(): SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def 
test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -131,12 +131,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): # Generate int8 weights w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 - w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 - w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) # Generate scale for each column (per-column quantization) w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale @@ -160,7 +158,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index df89ad7e6da6..a063876c7546 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -5,15 +5,15 @@ import torch import vllm._custom_ops as ops -from tests.kernels.quant_utils import (per_block_cast_to_fp8, - per_block_cast_to_int8) +from tests.kernels.quant_utils import per_block_cast_to_fp8, per_block_cast_to_int8 from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + BatchedPrepareAndFinalize, + BatchedTritonExperts, + NaiveBatchedExperts, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import round_up @@ -31,18 +31,20 @@ def triton_moe( per_act_token_quant=False, block_shape: Optional[list[int]] = None, ) -> torch.Tensor: - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_channel_quant=per_act_token_quant, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + return fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_channel_quant=per_act_token_quant, + use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, + block_shape=block_shape, + ) def batched_moe( @@ -62,10 +64,9 @@ def batched_moe( max_num_tokens = round_up(a.shape[0], 64) fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, @@ -75,15 +76,17 @@ def batched_moe( ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + 
return fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) def naive_batched_moe( @@ -103,10 +106,9 @@ def naive_batched_moe( max_num_tokens = round_up(a.shape[0], 64) fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), NaiveBatchedExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, @@ -116,19 +118,22 @@ def naive_batched_moe( ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) -def chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + scales: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: return scales @@ -151,13 +156,15 @@ def make_quantized_test_activations( a_scale = None if quant_dtype is not None: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" + assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, ( + "only fp8/int8 supported" + ) a_q = torch.zeros_like(a, dtype=quant_dtype) a_scale_l = [None] * E for e in range(E): a_q[e], a_scale_l[e] = moe_kernel_quantize_input( - a[e], None, quant_dtype, per_act_token_quant, block_shape) + a[e], None, quant_dtype, per_act_token_quant, block_shape + ) a_scale = torch.stack(a_scale_l) if not per_act_token_quant and block_shape is None: @@ -173,8 +180,9 @@ def moe_quantize_weights( per_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" + assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, ( + "only fp8/int8 supported" + ) if block_shape is not None: assert not per_token_quant @@ -185,10 +193,12 @@ def moe_quantize_weights( else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( - w, w_s, use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) else: w, w_s = ops.scaled_fp8_quant( - w, w_s, use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) return w, w_s @@ -209,7 +219,8 @@ def make_test_weight( w_s_l = [None] * e for idx in range(e): w_l[idx], w_s_l[idx] = moe_quantize_weights( - w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) + w_16[idx], None, quant_dtype, per_act_token_quant, block_shape + ) w = torch.stack(w_l) w_s = torch.stack(w_s_l) @@ -237,11 +248,19 @@ def make_test_weights( quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, - torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], +]: return ( - *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), - *make_test_weight(e, k, n, in_dtype, quant_dtype, 
block_shape, - per_act_token_quant), + *make_test_weight( + e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_act_token_quant + ), + *make_test_weight( + e, k, n, in_dtype, quant_dtype, block_shape, per_act_token_quant + ), ) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 6f43d1111c98..abd622ef5264 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -5,8 +5,7 @@ import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.platforms import current_platform from vllm.utils import round_up @@ -17,25 +16,31 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_dynamic_per_token_quant(x: torch.tensor, - quant_dtype: torch.dtype, - scale_ub: Optional[torch.tensor] = None) \ - -> tuple[torch.tensor, torch.tensor]: +def ref_dynamic_per_token_quant( + x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None +) -> tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, FP8_DTYPE] if scale_ub is not None: assert quant_dtype == FP8_DTYPE - qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ - else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.max - qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.min + qtype_traits = ( + torch.iinfo(quant_dtype) + if quant_dtype == torch.int8 + else torch.finfo(quant_dtype) + ) + qtype_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.max + ) + qtype_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.min + ) qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -56,15 +61,13 @@ def ref_dynamic_per_token_quant(x: torch.tensor, iscales = as_float32_tensor(s_1 / scales) torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) else: assert quant_dtype == FP8_DTYPE min_scaling_factor = s_1 / (qtype_max * s_512) scales = scales.clamp(min=min_scaling_factor) torch_out = as_float32_tensor(x) / scales - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) return torch_out, scales @@ -72,16 +75,20 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # The int8 version is very similar. 
Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel -def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ - -> tuple[torch.tensor, torch.tensor]: - +def ref_dynamic_per_tensor_fp8_quant( + x: torch.tensor, +) -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.max - fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.min + fp8_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.max + ) + fp8_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.min + ) fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) @@ -92,9 +99,12 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ x_max = as_float32_tensor(x.abs().max()) ref_scale = x_max / fp8_max ref_iscale = one / ref_scale - ref_out = (as_float32_tensor(x) * ref_iscale).clamp( - fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) - return ref_out, ref_scale.view((1, )) + ref_out = ( + (as_float32_tensor(x) * ref_iscale) + .clamp(fp8_traits_min, fp8_traits_max) + .to(FP8_DTYPE) + ) + return ref_out, ref_scale.view((1,)) def native_w8a8_block_matmul( @@ -126,7 +136,7 @@ def native_w8a8_block_matmul( M = A.numel() // A.shape[-1] N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) + origin_C_shape = A.shape[:-1] + (N,) A = A.reshape(M, A.shape[-1]) As = As.reshape(M, As.shape[-1]) n_tiles = (N + block_n - 1) // block_n @@ -137,19 +147,19 @@ def native_w8a8_block_matmul( C_shape = (M, N) C = torch.zeros(C_shape, dtype=compute_type, device=A.device) - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] for i in range(k_tiles): for j in range(n_tiles): @@ -163,14 +173,14 @@ def native_w8a8_block_matmul( return C -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): """Function to perform per-token-group quantization on an input tensor `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must " - "be divisible by `group_size`") + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" finfo = torch.finfo(dtype) @@ -178,28 +188,25 @@ def native_per_token_group_quant_fp8(x, fp8_max = finfo.max x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - 
keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / fp8_max x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s -def native_per_token_group_quant_int8(x, - group_size, - eps=1e-10, - dtype=torch.int8): +def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8): """Function to perform per-token-group quantization on an input tensor `x` using native torch. It converts the tensor values into int8 values and returns the quantized tensor along with the scaling factor used for quantization. """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` must be divisible by `group_size`" + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" iinfo = torch.iinfo(dtype) @@ -208,13 +215,13 @@ def native_per_token_group_quant_int8(x, x_ = x.reshape(x.numel() // group_size, group_size) # Use float32 for scale calculation for stability - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / int8_max - x_q = (x_.to(torch.float32) / x_s).round().clamp( - min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = ( + (x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype) + ) # Round before clamping x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s @@ -229,9 +236,9 @@ def per_block_cast_to_fp8( block_m, block_n = block_shape assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -248,9 +255,9 @@ def per_block_cast_to_int8( block_m, block_n = block_shape assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -288,8 +295,9 @@ def batched_dequant( assert t.shape[0] == scale.shape[0] out = torch.empty_like(t, dtype=out_dtype) for e in range(t.shape[0]): - out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, - out_dtype) + out[e] = dequant( + t[e], scale[e], block_shape, per_act_token_quant, out_dtype + ) return out return t.to(out_dtype) @@ -313,15 +321,17 @@ def native_batched_masked_quant_matmul( num_tokens = num_expert_tokens_cpu[e] if A.dtype.itemsize == 1 and block_shape is not None: assert A_scale is not None and B_scale is not None - tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e], - block_shape, C.dtype) + tmp = native_w8a8_block_matmul( + A[e], B[e], A_scale[e], 
B_scale[e], block_shape, C.dtype + ) C[e, :num_tokens, :] = tmp[:num_tokens, :] elif A.dtype.itemsize == 1 and block_shape is None: assert A_scale is not None and B_scale is not None A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant) B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant) - C[e, :num_tokens, :] = ( - A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype) + C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to( + C.dtype + ) else: assert A_scale is None assert B_scale is None diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 1095975ab2b4..db7feea10a5f 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -7,8 +7,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): @@ -21,12 +22,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def dequantize_nvfp4_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): +def dequantize_nvfp4_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. assert tensor_fp4.dtype == torch.uint8 diff --git a/tests/kernels/quantization/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py index 3de9cb364468..de0cd0874746 100644 --- a/tests/kernels/quantization/test_allspark_gemm.py +++ b/tests/kernels/quantization/test_allspark_gemm.py @@ -2,28 +2,29 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch - from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck + from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - ALLSPARK_AMPERE_N_ALIGN) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + ALLSPARK_AMPERE_K_ALIGN, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_AMPERE_N_ALIGN, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -def is_gptq_allspark_supported(min_capability: int, - max_capability: int) -> bool: +def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool: if not current_platform.is_cuda(): return False capability = current_platform.get_device_capability() assert capability is not None - return capability.to_int() >= min_capability \ - and capability.to_int() <= max_capability + return ( + capability.to_int() >= min_capability and capability.to_int() <= max_capability + ) MNK_FACTORS = [ @@ -43,7 +44,8 @@ def is_gptq_allspark_supported(min_capability: int, def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): @@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.skipif( not is_gptq_allspark_supported(80, 89), - 
reason="AllSpark Ampere kernel is not supported on this GPU type.") + reason="AllSpark Ampere kernel is not supported on this GPU type.", +) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("group_size", [-1]) @pytest.mark.parametrize("has_zp", HAS_ZP_OPTS) @@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): weight = rand_data((k, n), dtype=dtype) # Quantize (and apply act_order if provided) - w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128, - group_size, has_zp) + w_ref, qw, s, zp = quantize_weights( + weight, scalar_types.uint8b128, group_size, has_zp + ) qw = qw.to(torch.uint8) if has_zp: @@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): n_32align = (n + 32 - 1) // 32 * 32 - qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( - qw, s, zp, has_zp) - opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, - (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, - n_32align)) - - opcheck(torch.ops._C.allspark_w8a16_gemm, - (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, - sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) - output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, - n, group_size, sm_count, sm_version, - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - has_zp, True) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp) + opcheck( + torch.ops._C.rearrange_kn_weight_as_n32k16_order, + (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align), + ) + + opcheck( + torch.ops._C.allspark_w8a16_gemm, + ( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + output = ops.allspark_w8a16_gemm( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ) output_ref = torch.matmul(input, w_ref) torch.cuda.synchronize() diff --git a/tests/kernels/quantization/test_aqlm.py b/tests/kernels/quantization/test_aqlm.py index 427db3e60292..20eddf741781 100644 --- a/tests/kernels/quantization/test_aqlm.py +++ b/tests/kernels/quantization/test_aqlm.py @@ -2,39 +2,36 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops # noqa: F401 def test_aqlm_dequant_opcheck(): - codes = torch.randint(-32768, - 32767, (22016, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((2, 65536, 1, 8), - device='cuda', - dtype=torch.float16) + codes = torch.randint( + -32768, 32767, (22016, 512, 1), device="cuda", dtype=torch.int16 + ) + codebooks = torch.rand((2, 65536, 1, 8), device="cuda", dtype=torch.float16) codebook_partition_sizes = [11008, 11008] - opcheck(torch.ops._C.aqlm_dequant, - (codes, codebooks, codebook_partition_sizes)) + opcheck(torch.ops._C.aqlm_dequant, (codes, codebooks, codebook_partition_sizes)) def test_aqlm_gemm_opcheck(): - input = torch.rand((4, 4096), device='cuda', dtype=torch.float16) - codes = torch.randint(-32768, - 32767, (12288, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((3, 65536, 1, 8), - device='cuda', - dtype=torch.float16) - scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) + input = 
torch.rand((4, 4096), device="cuda", dtype=torch.float16) + codes = torch.randint( + -32768, 32767, (12288, 512, 1), device="cuda", dtype=torch.int16 + ) + codebooks = torch.rand((3, 65536, 1, 8), device="cuda", dtype=torch.float16) + scales = torch.rand((12288, 1, 1, 1), device="cuda", dtype=torch.float16) codebook_partition_sizes = [4096, 4096, 4096] bias = None - opcheck(torch.ops._C.aqlm_gemm, - (input, codes, codebooks, scales, codebook_partition_sizes, None)) - opcheck(torch.ops._C.aqlm_gemm, - (input, codes, codebooks, scales, codebook_partition_sizes, bias)) + opcheck( + torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, None), + ) + opcheck( + torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, bias), + ) diff --git a/tests/kernels/quantization/test_awq.py b/tests/kernels/quantization/test_awq.py index bc0868123d82..94be8805a0f3 100644 --- a/tests/kernels/quantization/test_awq.py +++ b/tests/kernels/quantization/test_awq.py @@ -3,45 +3,47 @@ import pytest import torch - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops # noqa: F401 -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not hasattr(torch.ops._C, "awq_dequantize"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) - zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16) + zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32) split_k_iters = 0 thx = 0 thy = 0 - opcheck(torch.ops._C.awq_dequantize, - (qweight, scales, zeros, split_k_iters, thx, thy)) + opcheck( + torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy), + ) @pytest.mark.skip(reason="Not working; needs investigation.") -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not hasattr(torch.ops._C, "awq_gemm"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.randint(-2000000000, - 2000000000, (64, 256), - device='cuda', - dtype=torch.int32) - qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + input = torch.rand((2, 8192), device="cuda", dtype=torch.float16) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.randint( + -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 + ) + qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16) split_k_iters = 8 - opcheck(torch.ops._C.awq_gemm, - (input, qweight, qzeros, scales, split_k_iters)) + opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) diff --git 
a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 96797e85bd12..b74c7b84120f 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -4,11 +4,15 @@ Run `pytest tests/kernels/test_awq_triton.py`. """ + import pytest import torch from vllm.model_executor.layers.quantization.awq_triton import ( - AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) + AWQ_TRITON_SUPPORTED_GROUP_SIZES, + awq_dequantize_triton, + awq_gemm_triton, +) from vllm.platforms import current_platform device = "cuda" @@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor): # qweights - [R , C // 8], int32 # scales - [R // G, C ], float16 # zeros - [R // G, C // 8], int32 -def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, - qzeros: torch.Tensor, - group_size: int) -> torch.Tensor: - +def awq_dequantize_torch( + qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int +) -> torch.Tensor: if group_size == -1: group_size = qweight.shape[0] bits = 4 shifts = torch.arange(0, 32, bits, device=qzeros.device) - iweights = torch.bitwise_right_shift(qweight[:, :, None], - shifts[None, None, :]).to(torch.int8) + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) iweights = iweights.view(iweights.shape[0], -1) - zeros = torch.bitwise_right_shift(qzeros[:, :, None], - shifts[None, None, :]).to(torch.int8) + zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) zeros = zeros.view(qzeros.shape[0], -1) zeros = reverse_awq_order(zeros) @@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, @pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) def test_dequantize(qweight_rows, qweight_cols, group_size): - if group_size == -1: group_size = qweight_rows @@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): current_platform.seed_everything(0) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - dtype=qweight_dtype, - device=device) - scales = torch.rand(scales_rows, - scales_cols, - dtype=scales_dtype, - device=device) - zeros = torch.randint(0, - torch.iinfo(torch.int32).max, - (zeros_rows, zeros_cols), - dtype=zeros_dtype, - device=device) + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device, + ) + scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device) + zeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device, + ) iweights_triton = awq_dequantize_triton(qweight, scales, zeros) - assert (not torch.any(torch.isinf(iweights_triton)) - and not torch.any(torch.isnan(iweights_triton))) + assert not torch.any(torch.isinf(iweights_triton)) and not torch.any( + torch.isnan(iweights_triton) + ) iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) @@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("splitK", [1, 8]) def test_gemm(N, K, M, splitK, group_size): - if group_size == -1: group_size = K @@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size): 
current_platform.seed_everything(0) - input = torch.rand((input_rows, input_cols), - dtype=input_dtype, - device=device) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - device=device) - qzeros = torch.randint(0, - torch.iinfo(torch.int32).max, - (qzeros_rows, qzeros_cols), - device=device) - scales = torch.rand((scales_rows, scales_cols), - dtype=scales_dtype, - device=device) - - output_triton = awq_gemm_triton(input, qweight, scales, qzeros, - split_k_iters) - - assert (not torch.any(torch.isinf(output_triton)) - and not torch.any(torch.isnan(output_triton))) + input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device) + qweight = torch.randint( + 0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device + ) + qzeros = torch.randint( + 0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device + ) + scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device) + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + + assert not torch.any(torch.isinf(output_triton)) and not torch.any( + torch.isnan(output_triton) + ) dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) output_torch = torch.matmul(input, dequantized_weights) - assert (not torch.any(torch.isinf(output_torch)) - and not torch.any(torch.isnan(output_torch))) + assert not torch.any(torch.isinf(output_torch)) and not torch.any( + torch.isnan(output_torch) + ) - torch.testing.assert_close(output_triton.cpu(), - output_torch.cpu(), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close( + output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1 + ) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 26aa8d652e63..b33ee5bd8a26 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -6,20 +6,23 @@ import pytest import torch +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul) + get_col_major_tma_aligned_tensor, + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -50,7 +53,8 @@ def setup_cuda(): @pytest.mark.parametrize( "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), +) @torch.inference_mode() def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): torch.manual_seed(seed) @@ -59,15 +63,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) out, scale = per_token_group_quant_fp8(x, group_size) - assert 
torch.allclose(out.to(torch.float32), - ref_out.to(torch.float32), - rtol=0.15) + assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) assert torch.allclose(scale, ref_scale) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -88,21 +91,20 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@pytest.mark.skipif(not has_deep_gemm(), - reason="DeepGemm kernels not available.") + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -122,20 +124,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): As = As_fp8.to(torch.float32) Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) + out = torch.zeros((M, N), device="cuda", dtype=out_dtype) - assert As_fp8.shape == (M, (K + 127) // - 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + assert As_fp8.shape == (M, (K + 127) // 128), ( + f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + ) fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index fac82cf9c8b5..9ca05563a194 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -6,16 +6,16 @@ import pytest import torch - from tests.kernels.quant_utils import native_w8a8_block_matmul + from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( - w8a8_block_int8_matmul) + w8a8_block_int8_matmul, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton 
requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -36,8 +36,10 @@ def setup_cuda(): torch.set_default_device("cuda") -@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py index 878f66647e19..e88f82d3e5cb 100644 --- a/tests/kernels/quantization/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -7,16 +7,15 @@ import pytest import torch - from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8 + from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) + sparse_cutlass_supported, +) from vllm.platforms import current_platform -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -40,9 +39,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -55,32 +52,31 @@ def prune_to_2_4(tensor): # This function checks that applying an identity matrix multiplication # to the compressed weights yields the original uncompressed weights. -def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, - b_compressed: torch.Tensor, - b_metadata: torch.Tensor): - +def check_compress_decompress_invariance( + dtype: torch.dtype, + b: torch.Tensor, + b_compressed: torch.Tensor, + b_metadata: torch.Tensor, +): # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the # same dtype as its inputs. This line addresses that constraint while # arbitrarily using bfloat16 for the int8/fp8 cases. 
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16 - eye = torch.eye(b.shape[0], device='cuda', dtype=dtype) - eye_scale = torch.ones(1, device='cuda', dtype=torch.float32) - b_decomp = ops.cutlass_scaled_sparse_mm(eye, - b_compressed, - b_metadata, - eye_scale, - eye_scale, - out_dtype=out_dtype) + eye = torch.eye(b.shape[0], device="cuda", dtype=dtype) + eye_scale = torch.ones(1, device="cuda", dtype=torch.float32) + b_decomp = ops.cutlass_scaled_sparse_mm( + eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype + ) torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp) def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int + dtype: torch.dtype, m: int, n: int, k: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') - b = torch.randn((n, k), device='cuda').t() + a = torch.randn((m, k), device="cuda") + b = torch.randn((n, k), device="cuda").t() if dtype == torch.int8: # ensure A and B aren't all zeros after rounding @@ -107,32 +103,25 @@ def make_rand_sparse_tensors( return b_compressed, e, a, b -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) # Test working with a subset of A and B for sparse matmul def test_cutlass_sparse_subset(): - big_m = 1024 m, n, k = 512, 512, 512 # Create tensors - b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, - big_m, n, k) + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k) a = whole_a[0:m, 0:k] scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16 + ) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) @@ -161,105 +150,87 @@ def test_cutlass_sparse_subset(): # Test working with a subset of A and B for sparse matmul -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype], - use_bias: bool): - +def test_cutlass_sparse_gemm( + m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32) scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32) - bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, 
scale_b, out_dtype=dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, k, n", MNK_FACTORS) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("use_bias", [True, False]) def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool): - # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + baseline = baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m,k,n", MNK_FACTORS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - +def test_cutlass_sparse_int8_gemm( + m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None - - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) - - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None + + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) + + 
baseline = baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index c4d349f1a5a0..dd1b02083c8d 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -4,12 +4,13 @@ Run `pytest tests/kernels/test_cutlass.py`. """ + import random import pytest import torch - from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8 + from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import cdiv @@ -36,9 +37,7 @@ (512, 24576, 128), ] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # -1 means full extent in that dimension TENSORWISE_GROUP_SHAPE = (-1, -1) @@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape): def scale_shape(shape, group_shape): assert len(shape) == len(group_shape) group_shape = group_scale_helper(shape, group_shape) - return tuple( - cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) - - -def cutlass_fp8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): + return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) + + +def cutlass_fp8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
a = to_fp8(torch.randn((m, k), device=device)) @@ -80,8 +80,8 @@ def cutlass_fp8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) # make scales M-major for blockwise quant, doesn't affect 1D scales scale_a = scale_a.t().contiguous().t() @@ -89,7 +89,7 @@ def cutlass_fp8_gemm_helper(m: int, scale_b = scale_b.t().contiguous().t() if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 else: bias = None @@ -98,18 +98,19 @@ def cutlass_fp8_gemm_helper(m: int, torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) -def cutlass_int8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): +def cutlass_int8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. a = to_int8(torch.randn((m, k), device=device) * 5) @@ -118,11 +119,11 @@ def cutlass_int8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 else: bias = None @@ -131,145 +132,192 @@ def cutlass_int8_gemm_helper(m: int, torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) +@pytest.mark.skipif( + not 
current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, - a_scale_group_shape, - b_scale_group_shape, use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0: return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return if m % 4 != 0 and current_platform.has_device_capability(100): return - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +def test_cutlass_int8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_int8_gemm_helper( + m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +def test_cutlass_int8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: 
type[torch.dtype], + use_bias: bool, +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape, - b_scale_group_shape, use_bias, torch.bfloat16, - device) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.skipif( + not 
current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + torch.bfloat16, + device, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=torch.bfloat16, - device=device) +def test_cutlass_int8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=torch.bfloat16, + device=device, + ) # For the following two tests: @@ -277,32 +325,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, # of a large power of two. In any case, the kernel will have a naive fallback # when N and K are not divisible by 16. But M is the number of tokens and the # kernel must handle any M thrown at it. -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_fp8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +def test_cutlass_int8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_int8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) 
@pytest.mark.parametrize("m", [32, 64, 128]) @@ -310,8 +368,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, @pytest.mark.parametrize("k", [64, 128, 256]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.skip -def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, - out_dtype: torch.dtype): +def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype): # Currently, the test is failing because folding azp into # 16-bit bias loses too much precision scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 @@ -328,7 +385,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, b_dq = scale_b * bq_f32 - azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding @@ -340,18 +397,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, J = torch.ones((1, k), device="cuda", dtype=torch.float32) azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype) assert azp_bias.shape == (1, n) - assert azp_bias[0, :].shape == (n, ) - - baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * ( - (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to( - dtype=out_dtype, device='cuda') - - out = ops.cutlass_scaled_mm(aq_i8, - bq_i8, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=azp_bias[0, :]) + assert azp_bias[0, :].shape == (n,) + + baseline_q = ( + scale_a.to(device="cpu") + * scale_b.to(device="cpu") + * ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu")) + ).to(dtype=out_dtype, device="cuda") + + out = ops.cutlass_scaled_mm( + aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :] + ) torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0) torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0) @@ -362,8 +418,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("azp_per_token", [True, False]) -def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, - use_bias: bool, azp_per_token: bool): +def test_cutlass_int8_azp( + m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool +): m_azp = m if azp_per_token else 1 scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10 @@ -377,16 +434,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, bq_f32 = bq_i8.to(dtype=torch.float32) b_dq = scale_b * bq_f32 - azp_a = torch.rand( - (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32) - torch.testing.assert_close(a_dq, - scale_a * aq_f32 - azp_a, - rtol=1e-4, - atol=1e-3) + torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3) if use_bias: bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5 @@ -396,8 +449,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: 
torch.dtype, baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype) # int32 mm not supported on CUDA - a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu') - cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda') + a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu") + cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda") baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype) # Hadamard is just the sum of the cols @@ -406,14 +459,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, func_bias = bias if use_bias else None if azp_per_token: - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_adj_i32, azp_i32, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias + ) else: azp_with_adj_i32 = azp_i32 * azp_adj_i32 - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_with_adj_i32, None, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias + ) # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4% # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05% @@ -423,13 +476,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol) if azp_per_token: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias), + ) else: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias), + ) # Test working with a subset of A and B @@ -445,23 +500,14 @@ def test_cutlass_subset(): scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) # Test to make sure cuda graphs work class CutlassLayer(torch.nn.Module): - def __init__(self, b, scale_a, scale_b, out_dtype): super().__init__() self.b = b @@ -470,8 +516,9 @@ def __init__(self, b, scale_a, scale_b, out_dtype): self.out_dtype = out_dtype def forward(self, a): - return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b, - self.out_dtype) + return ops.cutlass_scaled_mm( + a, self.b, self.scale_a, self.scale_b, self.out_dtype + ) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -485,10 +532,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): m_a_scales = m if per_act_token else 1 n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn( - (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) - scale_b = (torch.randn( - (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) + scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, n_b_scales), device="cuda", 
dtype=torch.float32) / 10 # Construct a trivial model with a single layer that calls a CUTLASS kernel model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) @@ -502,13 +547,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): out.zero_() g.replay() - baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + baseline = torch.mm( + scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32) + ).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) def test_cutlass_support_opcheck(): - opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,)) @pytest.mark.parametrize("num_experts", [8, 64]) @@ -517,11 +563,13 @@ def test_cutlass_support_opcheck(): @pytest.mark.parametrize("use_bias", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) +def test_cutlass_fp8_group_gemm( + num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Device and dtype setup device = "cuda" out_dtype = torch.half @@ -533,13 +581,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_scales_tensors = [] baseline_tensors = [] - expert_offsets = torch.zeros((num_experts + 1), - device=device, - dtype=torch.int32) + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) - problem_sizes = torch.zeros((num_experts, 3), - device=device, - dtype=torch.int32) + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) if not per_act_token: one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32) @@ -568,77 +612,78 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_tensors.append(b_g) # Set up A/B scales - scale_b = torch.randn((1, n_b_scales), - device=device, - dtype=torch.float32) + scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32) b_scales_tensors.append(scale_b) if per_act_token: - scale_a = torch.randn((m_a_scales, 1), - device=device, - dtype=torch.float32) + scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32) a_scales_tensors.append(scale_a) else: scale_a = one_scale_a # Compute baseline result for this group - baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, - None) + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) baseline_tensors.append(baseline_g) - a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), - device=device, - dtype=torch.float8_e4m3fn) - b_tensors_stacked = torch.empty((num_experts, n_g, k_g), - device=device, - dtype=torch.float8_e4m3fn) + a_tensors_stacked = torch.empty( + (expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_tensors_stacked = torch.empty( + (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) for g in range(num_experts): - a_tensors_stacked[expert_offsets[g]:expert_offsets[g + - 1]] = a_tensors[g] + a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] b_tensors_stacked[g] = b_tensors[g].t() 
b_tensors_stacked = b_tensors_stacked.transpose(1, 2) if per_act_token: a_scales_tensors_stacked = torch.empty( - (expert_offsets[num_experts], 1), - device=device, - dtype=torch.float32) + (expert_offsets[num_experts], 1), device=device, dtype=torch.float32 + ) for g in range(num_experts): - a_scales_tensors_stacked[ - expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] + a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = ( + a_scales_tensors[g] + ) else: a_scales_tensors_stacked = one_scale_a - b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales), - device=device, - dtype=torch.float32) + b_scales_tensors_stacked = torch.empty( + (num_experts, n_b_scales), device=device, dtype=torch.float32 + ) for g in range(num_experts): b_scales_tensors_stacked[g] = b_scales_tensors[g] - out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), - device=device, - dtype=out_dtype) - - ab_strides = torch.full((num_experts, ), - a_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) - c_strides = torch.full((num_experts, ), - out_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) - - ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, - b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, expert_offsets[:-1], - problem_sizes, ab_strides, ab_strides, c_strides, - per_act_token, per_out_ch) + out_tensors_stacked = torch.zeros( + (expert_offsets[num_experts], n_g), device=device, dtype=out_dtype + ) + + ab_strides = torch.full( + (num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) + c_strides = torch.full( + (num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) + + ops.cutlass_moe_mm( + out_tensors_stacked, + a_tensors_stacked, + b_tensors_stacked, + a_scales_tensors_stacked, + b_scales_tensors_stacked, + expert_offsets[:-1], + problem_sizes, + ab_strides, + ab_strides, + c_strides, + per_act_token, + per_out_ch, + ) # Validate each group's result against the baseline for g in range(num_experts): baseline = baseline_tensors[g] - c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] print(baseline) print(c) print("*") diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index 0a3edd4ddc16..0372fe285cf2 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -3,40 +3,56 @@ import pytest import torch +from tests.kernels.quant_utils import ( + FP8_DTYPE, + ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant, +) +from tests.kernels.utils import opcheck import vllm._custom_ops as ops -from tests.kernels.quant_utils import (FP8_DTYPE, - ref_dynamic_per_tensor_fp8_quant, - ref_dynamic_per_token_quant) -from tests.kernels.utils import opcheck from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, - 8193] # Arbitrary values for testing +HIDDEN_SIZES = [ + 1, + 2, + 3, + 4, + 16, + 67, + 768, + 2048, + 5120, + 5137, + 8192, + 8193, +] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SCALE_UBS = [True, False] SEEDS = [0] -def opcheck_fp8_quant(output, - input, - scale=None, - scale_ub=None, - use_per_token_if_dynamic=False): +def 
opcheck_fp8_quant( + output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False +): if scale is not None: opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) elif use_per_token_if_dynamic: - scale = torch.empty((input.shape[0], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, - (output, input, scale, scale_ub)) + scale = torch.empty( + (input.shape[0], 1), device=input.device, dtype=torch.float32 + ) + opcheck( + torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub), + ) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32, + ) opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) @@ -46,30 +62,29 @@ def opcheck_fp8_quant(output, @pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, scale_ub: bool, - seed: int) -> None: +def test_dynamic_per_token_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int +) -> None: current_platform.seed_everything(seed) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") + 1e-6 # avoid nans + x = ( + torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 + ) # avoid nans - scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ - if scale_ub else None + scale_ub = ( + torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None + ) ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub) - ops_out, ops_scales = ops.scaled_fp8_quant(x, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + ops_out, ops_scales = ops.scaled_fp8_quant( + x, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) torch.testing.assert_close(ref_scales, ops_scales) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) - opcheck_fp8_quant(ops_out, - x, - None, - scale_ub, - use_per_token_if_dynamic=True) + opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -77,8 +92,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_per_tensor_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -87,8 +103,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, ops_out, ops_scale = ops.scaled_fp8_quant(x) torch.testing.assert_close(ref_scale, ops_scale) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck_fp8_quant(ops_out, x) diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index 
07651fef39bf..eac7f4d4f250 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -4,8 +4,8 @@ import gguf import pytest import torch - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops # noqa: F401 @@ -13,33 +13,42 @@ def test_ggml_opcheck(quant_type): block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] shape = [256, 1152] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) m = qweight.shape[0] n = qweight.shape[1] // type_size * block_size - opcheck(torch.ops._C.ggml_dequantize, - (qweight, quant_type, m, n, torch.float16)) + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16)) - x = torch.rand((m, 512), device='cuda', dtype=torch.float16) - opcheck(torch.ops._C.ggml_mul_mat_a8, - (qweight, x, quant_type, qweight.shape[0])) - opcheck(torch.ops._C.ggml_mul_mat_vec_a8, - (qweight, x, quant_type, qweight.shape[0])) + x = torch.rand((m, 512), device="cuda", dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0])) + opcheck( + torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0]) + ) shape = [256, 1024, 336] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) - x = torch.rand((1, 1024), device='cuda', dtype=torch.float16) - sorted_token_ids = torch.arange(776, device='cuda') - expert_ids = torch.randint(0, 256, (194, ), device='cuda') - num_tokens_post_padded = torch.tensor([1], - dtype=torch.int64, - device='cuda') + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) + x = torch.rand((1, 1024), device="cuda", dtype=torch.float16) + sorted_token_ids = torch.arange(776, device="cuda") + expert_ids = torch.randint(0, 256, (194,), device="cuda") + num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda") - opcheck(torch.ops._C.ggml_moe_a8, - (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, - quant_type, qweight.shape[0], 1, x.shape[0])) - - topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + opcheck( + torch.ops._C.ggml_moe_a8, + ( + x, + qweight, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + quant_type, + qweight.shape[0], + 1, + x.shape[0], + ), + ) + + topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32) opcheck( torch.ops._C.ggml_moe_a8_vec, - (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]), + ) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 436d5cb64021..0988ba01759f 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -18,8 +18,8 @@ def get_gguf_sample_tensors( - hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -27,8 +27,8 @@ def get_gguf_sample_tensors( def get_gguf_MoE_tensors( - hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE_MOE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -68,17 
+68,20 @@ def get_gguf_MoE_tensors( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_dequantize(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_dequantize( + hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType +): tensors = get_gguf_sample_tensors(hidden_size, quant_type) for tensor in tensors: shape_str = tensor.name.split("_")[-1] shape = map(int, shape_str.split("x")) - ref_output = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) - output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), - quant_type, *list(shape), dtype) + ref_output = torch.tensor( + dequantize(tensor.data, quant_type), device="cuda" + ).to(dtype) + output = ops.ggml_dequantize( + torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype + ) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) @@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_mmvq(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") - output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, - qweight.shape[0]).to(dtype) + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to( + dtype + ) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) @@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, GGMLQuantizationType.Q4_0, GGMLQuantizationType.Q5_0, GGMLQuantizationType.Q8_0, - ]) + ], +) @torch.inference_mode() -def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmq( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, +): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") @@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, # bfloat16 tends to accumulate and can greatly inflate rtol # since outputs are also very close to 0 rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1} - torch.testing.assert_close(output, - ref_output, - atol=atols[dtype], - rtol=rtols[dtype]) + torch.testing.assert_close( + output, ref_output, atol=atols[dtype], rtol=rtols[dtype] + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) 
@pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType, top_k: int): +def test_moe( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, + top_k: int, +): current_platform.seed_everything(0) H, E = 1024, 256 x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, - E, (num_tokens, top_k), - device="cuda", - dtype=torch.int32) + topk_ids = torch.randint( + 0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32 + ) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) w13 = tensors[0] w2 = tensors[1] - w13_dequant = torch.tensor(dequantize(w13.data, quant_type), - device="cuda").to(dtype) - - w2_dequant = torch.tensor(dequantize(w2.data, quant_type), - device="cuda").to(dtype) - - output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"), - torch.tensor(w2.data, - device="cuda"), topk_weights, - topk_ids, quant_type, quant_type, "silu") - - ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights, - topk_ids).reshape(output.shape) + w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to( + dtype + ) + + w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype) + + output = _fused_moe_gguf( + x, + torch.tensor(w13.data, device="cuda"), + torch.tensor(w2.data, device="cuda"), + topk_weights, + topk_ids, + quant_type, + quant_type, + "silu", + ) + + ref_output = fused_experts( + x, w13_dequant, w2_dequant, topk_weights, topk_ids + ).reshape(output.shape) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_gptq.py b/tests/kernels/quantization/test_gptq.py index 7fb57a1576bd..32782be3cee4 100644 --- a/tests/kernels/quantization/test_gptq.py +++ b/tests/kernels/quantization/test_gptq.py @@ -2,31 +2,28 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops # noqa: F401 def test_gptq_shuffle_opcheck(): - weight = torch.randint(-2000000, - 2000000, (1792, 4096), - device='cuda', - dtype=torch.int32) - perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + weight = torch.randint( + -2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32 + ) + perm = torch.empty((0,), device="cuda", dtype=torch.int32) bit = 4 opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) def test_gptq_gemm_opcheck(): - a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) - weight = torch.randint(-2000000, - 2000000, (512, 6144), - device='cuda', - dtype=torch.int32) - zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) - scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) - idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + a = torch.rand((240, 4096), device="cuda", dtype=torch.float16) + weight = torch.randint( + -2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32 + ) + zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32) + scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16) + idx = torch.empty((0,), device="cuda", dtype=torch.int32) use_exllama = True bit = 4 - opcheck(torch.ops._C.gptq_gemm, - (a, weight, zeros, scales, idx, use_exllama, bit)) + opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, 
idx, use_exllama, bit)) diff --git a/tests/kernels/quantization/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py index dc5fecbf4ccc..3af4da6d7e05 100644 --- a/tests/kernels/quantization/test_int8_kernel.py +++ b/tests/kernels/quantization/test_int8_kernel.py @@ -10,12 +10,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_quant_int8) + per_token_quant_int8, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): @@ -25,14 +25,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -66,25 +65,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): mask = topk_ids == i if mask.sum(): # First MLP layer: note that a_s is now per-token - inter_out = native_w8a8_per_token_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - output_dtype=a.dtype) + inter_out = native_w8a8_per_token_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype + ) # Activation function act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = per_token_quant_int8(act_out) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -102,8 +98,10 @@ def setup_cuda(): SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -144,7 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py index 5a37b976db9e..7b66077ca13c 100644 --- a/tests/kernels/quantization/test_int8_quant.py +++ 
b/tests/kernels/quantization/test_int8_quant.py @@ -3,9 +3,9 @@ import pytest import torch - from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.utils import opcheck + from vllm._custom_ops import scaled_int8_quant from vllm.platforms import current_platform @@ -19,26 +19,24 @@ def opcheck_int8_quant_static(output, input, scale, azp=None): if azp is None: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None)) else: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, azp)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp)) def opcheck_int8_quant_dynamic(output, input, symmetric=True): - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) if symmetric: - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None)) else: - azp = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.int32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, azp)) + azp = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32, + ) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -46,8 +44,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -69,30 +68,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_azp_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) - azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( - torch.int32) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32) - torch_out = ((x / scales).round() + azps).clamp( - int8_traits.min, int8_traits.max).to(torch.int8) - assert torch_out.min() >= int8_traits.min and torch_out.max( - ) <= int8_traits.max + torch_out = ( + ((x / scales).round() + azps) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) + assert torch_out.min() >= 
int8_traits.min and torch_out.max() <= int8_traits.max ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False) - if (not torch.allclose(scales_out, scales)): + if not torch.allclose(scales_out, scales): print(torch.argmax(torch.abs(scales_out - scales))) torch.testing.assert_close(scales_out, scales) # big atol to account for rounding errors @@ -109,17 +109,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("scale", SCALE) @torch.inference_mode() -def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float) -> None: +def test_static_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") - out1 = (x / scale_arg).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + (x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) + ) out2, scale2, _ = scaled_int8_quant(x, scale_arg) assert scale2 is scale_arg @@ -136,24 +137,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("scale", SCALE) @pytest.mark.parametrize("azp", [-255, 54]) @torch.inference_mode() -def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float, azp: int) -> None: +def test_static_scaled_int8_azp_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + scale: float, + azp: int, +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 - out1 = ((x / scale).round() + azp).clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + ((x / scale).round() + azp) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") - out2, scale2, azp2 = scaled_int8_quant(x, - scale_arg, - azp_arg, - symmetric=False) + out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False) assert scale2 is scale_arg assert azp2 is azp_arg @@ -173,10 +178,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: int32_traits = torch.iinfo(torch.int32) val = float(int32_traits.max if is_max else int32_traits.min) - x_vals = [[ - nextafter(val, inf), val + 1, val, val - 1, - nextafter(val, -inf) - ]] + x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]] x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") # The calculation in the kernel is: cast(cast(x / scale) + azp) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index a7cb2a4e7f21..72d9d3e0e404 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -11,19 +11,20 @@ import pytest import torch - from tests.kernels.utils import opcheck + from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.machete_utils import ( - 
query_machete_supported_group_sizes) + query_machete_supported_group_sizes, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # unit tests to a common utility function. Currently the use of @@ -76,46 +77,63 @@ class Tensors: # Ch Scales Type, Tok Scales Type) # NOTE: None "Scale Type" means the act type is floating point # None "Output Type" means the output type is the same as the act type -TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], - Optional[torch.dtype], bool] +TestTypeTuple = tuple[ + list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool +] TEST_TYPES = [ # GPTQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16] + ), # AWQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=a_type, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4, scalar_types.uint8] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16] + ), # QQQ style - *(TypeConfig(act_type=torch.int8, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), - *(TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), + *( + TypeConfig( + act_type=torch.int8, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float, + ) + for group_scale_type in [None, torch.float16] + ), + *( + TypeConfig( + act_type=torch.float8_e4m3fn, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float, + ) + for group_scale_type in [None, torch.float16] + ), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel @@ -137,17 +155,18 
@@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): return zps if zps is None else -1 * s * (zps.to(s.dtype)) -def group_size_valid(shape: tuple[int, int, int], - group_size: Optional[int]) -> bool: +def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: return group_size is None or group_size == -1 or shape[2] % group_size == 0 -def machete_quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def machete_quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -156,7 +175,8 @@ def machete_quantize_and_pack(atype: torch.dtype, group_size=group_size, zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) w_q = w_q.t().contiguous().t() # convert to col major @@ -167,15 +187,18 @@ def machete_quantize_and_pack(atype: torch.dtype, return w_ref, w_q_machete, w_s, w_zp -def create_test_tensors(shape: tuple[int, int, int], - types: TypeConfig, - group_size: Optional[int], - subset_stride_factor: Optional[int] = None) -> Tensors: +def create_test_tensors( + shape: tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None, +) -> Tensors: m, n, k = shape factor = subset_stride_factor or 1 - print("create_test_tensors, shape:", shape, "types:", types, "group_size:", - group_size) + print( + "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size + ) a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) @@ -190,8 +213,13 @@ def create_test_tensors(shape: tuple[int, int, int], w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -200,35 +228,47 @@ def create_test_tensors(shape: tuple[int, int, int], a_ref = a.to(torch.float32) w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) - return Tensors(w_ref=w_ref, - a_ref=a_ref, - a=a, - w_q=w_q_packed, - w_g_s=w_s, - w_g_zp=maybe_convert_zeropoints(w_zp, w_s), - w_ch_s=w_ch_s, - w_tok_s=w_tok_s) + return Tensors( + w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) # None stype means scales use the same dtype as a -def machete_mm_test_helper(types: TypeConfig, - tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = 
None): +def machete_mm_test_helper( + types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None, +): output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) output_ref_type = output_ref.dtype if tensors.w_ch_s is not None: - output_ref = (output_ref.to(tensors.w_ch_s.dtype) * - tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0) + ).to(output_ref_type) if tensors.w_tok_s is not None: - output_ref = (output_ref.to(tensors.w_tok_s.dtype) * - tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1) + ).to(output_ref_type) output = ops.machete_mm( a=tensors.a, @@ -249,23 +289,23 @@ def machete_mm_test_helper(types: TypeConfig, # Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol when we have zeropoints since the way machete applies # zeropoints (after scales) causes noise around 0 - atol = 1 if tensors.w_g_zp is not None\ + atol = ( + 1 + if tensors.w_g_zp is not None else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + ) rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 - torch.testing.assert_close(output, - output_ref.to(output.dtype), - rtol=rtol, - atol=atol) + torch.testing.assert_close( + output, output_ref.to(output.dtype), rtol=rtol, atol=atol + ) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_all_schedules(shape, types: TypeConfig): - group_sizes: list[Optional[int]] = [] if types.group_scale_type is None: group_sizes = [None] @@ -279,20 +319,20 @@ def test_machete_all_schedules(shape, types: TypeConfig): tensors = create_test_tensors(shape, types, group_size) print(f"MNK = {shape}") for schedule in ops.machete_supported_schedules( - types.act_type, - types.weight_type, - group_scales_type=types.group_scale_type, - group_zeros_type=types.group_scale_type, - out_type=types.output_type): + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type, + ): print(f"Testing schedule {schedule}") machete_mm_test_helper(types, tensors, group_size, schedule) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_heuristic(shape, types: TypeConfig): group_sizes: list[Optional[int]] = [] @@ -310,19 +350,22 @@ def test_machete_heuristic(shape, types: TypeConfig): # Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." 
+) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_machete_devices(device: str): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) @@ -335,29 +378,30 @@ def test_machete_devices(device: str): # Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) def test_machete_subset(): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) - - tensors = create_test_tensors((512, 4096, 4096), - type_config, - group_size, - subset_stride_factor=2) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) + + tensors = create_test_tensors( + (512, 4096, 4096), type_config, group_size, subset_stride_factor=2 + ) machete_mm_test_helper(type_config, tensors, group_size) # Test to make sure cuda graphs work class MacheteLayer(torch.nn.Module): - def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs @@ -366,8 +410,9 @@ def forward(self, a): return ops.machete_mm(a=a, **self.kwargs) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) def test_machete_cuda_graph(): m, n, k = 512, 4096, 4096 @@ -379,7 +424,8 @@ def test_machete_cuda_graph(): zero_points = False w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, b, wtype, stype, group_size, zero_points) + a.dtype, b, wtype, stype, group_size, zero_points + ) # Construct a trivial model with a single layer that calls a machete kernel model = MacheteLayer( diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 92914bd5cbba..9850332d385c 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -4,36 +4,61 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. 
""" + import pytest import torch - from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.quantization.utils import is_quant_method_supported + from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) + MARLIN_QQQ_MAX_PARALLEL, + MARLIN_QQQ_MIN_THREAD_N, + MARLIN_QQQ_SUPPORTED_GROUP_SIZES, + MARLIN_QQQ_SUPPORTED_NUM_BITS, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, - query_marlin_supported_quant_types) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + MARLIN_SUPPORTED_GROUP_SIZES, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_scales, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_fp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, - marlin_weights) + MarlinWorkspace, + awq_marlin_quantize, + get_weight_perm, + marlin_quantize, + marlin_weights, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 - marlin_qqq_quantize) + marlin_qqq_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + awq_pack, + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] @@ -65,24 +90,27 @@ def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False, False)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def 
test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - act_order, mnk_factors): +def test_gptq_marlin_repack( + k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors +): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -105,7 +133,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - b_weight, quant_type, group_size, act_order) + b_weight, quant_type, group_size, act_order + ) # Pack to GPTQ format q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) @@ -118,11 +147,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.gptq_marlin_repack, - (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.gptq_marlin_repack, + (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -137,16 +169,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -161,21 +193,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights(b_weight, - quant_type, - group_size, - zero_points=True) + w_ref, q_w, s, zp = quantize_weights( + b_weight, quant_type, group_size, zero_points=True + ) # Pack to AWQ format q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.awq_marlin_repack, - (q_w_awq, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( @@ -189,14 +222,16 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not 
is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) @pytest.mark.parametrize( - "group_size", - set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) + "group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) +) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -238,7 +273,8 @@ def test_gptq_marlin_gemm( if group_size != 16 or act_order: return w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( - b_weight.T, group_size) + b_weight.T, group_size + ) g_idx = None sort_indices = None marlin_zp = None @@ -247,8 +283,7 @@ def test_gptq_marlin_gemm( return if act_order: return - w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( - b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size) g_idx = None sort_indices = None marlin_zp = None @@ -257,7 +292,8 @@ def test_gptq_marlin_gemm( if group_size == 16: return w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) + b_weight, quant_type, group_size + ) g_idx = None sort_indices = None marlin_s2 = None @@ -265,18 +301,36 @@ def test_gptq_marlin_gemm( if group_size == 16: return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order) + b_weight, quant_type, group_size, act_order + ) marlin_zp = None marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, - sort_indices, workspace, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, - use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_gemm, + ( + a_input, + None, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + use_atomic_add, + use_fp32_reduce, + False, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = ops.gptq_marlin_gemm( a_input, @@ -308,23 +362,40 @@ def test_gptq_marlin_gemm( # TODO: find better way to test this? 
@torch.compile(fullgraph=True) -def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, size_n, - size_k): - return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, - size_n, size_k) +def marlin_24_gemm_tester( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, +): + return ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, + ) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -334,19 +405,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize( + b_weight, quant_type, group_size + ) - workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + workspace_24 = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) output_ref = torch.matmul(a_input, w_24_ref) - opcheck(torch.ops._C.gptq_marlin_24_gemm, - (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, - workspace_24.scratch, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1]), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_24_gemm, + ( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = marlin_24_gemm_tester( a_input, @@ -367,8 +450,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES) @@ -392,22 +477,22 @@ def test_hqq_marlin_gemm( a_input = rand_data((size_m, size_k)) dev = a_input.device - b_weight = torch.randint(0, - 10, (size_n, size_k), - dtype=torch.uint8, - device=dev) + b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev) scale = rand_data((size_n, size_k // group_size)) zero = rand_data((size_n, size_k // group_size)) 
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n) sort_indices = torch.empty(0, dtype=torch.int, device=dev) - marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, - 4).to(dev) - marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n, - group_size).to(dev) - marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n, - group_size).to(dev) + marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to( + dev + ) + marlin_s = marlin_permute_scales( + scale.transpose(1, 0), size_k, size_n, group_size + ).to(dev) + marlin_zp = marlin_permute_scales( + zero.transpose(1, 0), size_k, size_n, group_size + ).to(dev) g_idx = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -438,8 +523,7 @@ def test_hqq_marlin_gemm( s_flat = scale.reshape(-1, 1) dequant = (b_flat - zp_flat) * s_flat - output_ref = torch.matmul(a_input, - dequant.reshape(b_weight.shape).transpose(1, 0)) + output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0)) torch.cuda.synchronize() @@ -448,8 +532,10 @@ def test_hqq_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("qqq"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("qqq"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) @@ -473,22 +559,34 @@ def test_marlin_qqq_gemm( b_weight = rand_data((size_k, size_n)) # Quantize activations - s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( - torch.float) - q_a = (a_input / s_a).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + s_a = ( + a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(torch.float) + ) + q_a = (a_input / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) # Quantize weights - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ - marlin_qqq_quantize(b_weight, num_bits, group_size) + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = ( + marlin_qqq_quantize(b_weight, num_bits, group_size) + ) - workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_MAX_PARALLEL) + workspace = MarlinWorkspace( + size_n, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_MAX_PARALLEL + ) - opcheck(torch.ops._C.marlin_qqq_gemm, - (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel, - marlin_qqq_s_group, workspace.scratch, a_input.shape[0], - b_weight.shape[1], a_input.shape[1])) + opcheck( + torch.ops._C.marlin_qqq_gemm, + ( + q_a, + marlin_qqq_q_w, + s_a, + marlin_qqq_s_channel, + marlin_qqq_s_group, + workspace.scratch, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ), + ) output = ops.marlin_qqq_gemm( q_a, @@ -518,11 +616,12 @@ def test_marlin_gemm_subset_input(): big_m = size_m * 2 big_k = size_k * 2 - a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8] + a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8] b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, False) + b_weight, quant_type, group_size, False + ) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) @@ -559,11 +658,12 @@ def test_marlin_gemm_opcheck(): size_m = 2048 
size_n = 4096 size_k = 4096 - a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) - w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) - s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) - wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL).scratch + a = torch.rand((size_m, size_n), device="cuda", dtype=torch.float16) + w = torch.randint(-5, 5, (256, 8192), device="cuda", dtype=torch.int32) + s = torch.full((32, size_k), 0.125, device="cuda", dtype=torch.float16) + wk = MarlinWorkspace( + size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ).scratch x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) torch.testing.assert_close(x, y) diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 3a8f4c17598c..e9b091d06697 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -8,15 +8,27 @@ from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] -PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), - (90, 128), (150, 128), (150, 48), (90, 80)] +PAD_SHAPES = [ + (90, 64), + (150, 64), + (128, 48), + (128, 80), + (150, 80), + (90, 48), + (90, 128), + (150, 128), + (150, 48), + (90, 80), +] SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max @@ -31,7 +43,22 @@ # 0001 -> 0.5 # 0000 -> 0 E2M1_TO_FLOAT32 = [ - 0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6. 
+ 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] BLOCK_SIZE = 16 @@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale): assert x.ndim == 2 m, n = x.shape x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) @@ -131,7 +157,7 @@ def test_quantize_to_fp4( def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: dtype = torch.float16 current_platform.seed_everything(42) - torch.set_default_device('cuda:0') + torch.set_default_device("cuda:0") m, n = pad_shape diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 0b45c2298175..d2a352ce8445 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch -from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype from vllm import _custom_ops as ops from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] # m, n, k @@ -19,26 +20,31 @@ SHAPES.extend(PAD_SHAPES) SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] -def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, - m, n, dtype, block_size, device): +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): _, m_k = a_fp4.shape _, n_k = b_fp4.shape - assert (m_k == n_k) - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + ) + b_in_dtype = dequantize_nvfp4_to_dtype( + b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size + ) return torch.matmul(a_in_dtype, b_in_dtype.t()) @@ -60,22 +66,31 @@ def test_nvfp4_gemm( a_dtype = torch.randn((m, k), dtype=dtype, device=device) b_dtype = torch.randn((n, k), dtype=dtype, device=device) - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) - b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) - alpha = 1. 
/ (a_global_scale * b_global_scale) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) - expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, a_global_scale, - b_global_scale, m, n, dtype, block_size, - device) - out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, alpha, dtype) + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + out = ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype + ) - torch.testing.assert_close(out, - expected_out.to(dtype=dtype), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 533a4fe59677..539689989fff 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch +from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant import vllm._custom_ops as ops -from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] @@ -20,8 +20,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @torch.inference_mode() def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): torch.manual_seed(seed) @@ -39,8 +38,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): @pytest.mark.parametrize("m", [8] + M) # m >= 8 @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() @@ -61,7 +59,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), - reason="only test for rocm fp8") + reason="only test for rocm fp8", +) def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) @@ -71,12 +70,9 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - ref_out = torch._scaled_mm(A, - B.t(), - out_dtype=dtype, - scale_a=scale_a, - scale_b=scale_b) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, - current_platform.get_cu_count()) + ref_out = torch._scaled_mm( + A, 
B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b + ) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count()) assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 8a2cc3baced2..26d49dad7396 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_triton_scaled_mm.py`. """ + import importlib from typing import Optional @@ -15,17 +16,19 @@ device = "cuda" triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") + "vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm" +) triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm -def torch_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = scale_a * out out = scale_b.T * out @@ -44,20 +47,22 @@ def get_8bit_types(): # This test is to check regressions for int8 support on ROCm. -@pytest.mark.parametrize("model_path", [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", -]) +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Should only run on ROCm") -def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, - max_tokens, num_logprobs): +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm") +def test_rocm_compressed_tensors_w8a8( + vllm_runner, example_prompts, model_path, max_tokens, num_logprobs +): dtype = "bfloat16" with vllm_runner(model_path, dtype=dtype) as vllm_model: - vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, - num_logprobs) + vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) @pytest.mark.parametrize("M", [1, 33, 64, 512]) @@ -68,10 +73,10 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, @pytest.mark.parametrize("use_scalar_scale_a", [True, False]) @pytest.mark.parametrize("use_scalar_scale_b", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, - use_scalar_scale_b, use_bias): - is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t - ).is_floating_point() +def test_scaled_mm( + M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias +): + is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point() current_platform.seed_everything(0) @@ -85,10 +90,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, # # So, the values here are kept small enough to avoid this situation. 
if is_floating_point_type(in_dtype): - a = (0.25 * torch.rand( - (M, K), dtype=torch.float32, device=device)).to(in_dtype) - b = (0.25 * torch.rand( - (K, N), dtype=torch.float32, device=device)).to(in_dtype) + a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype) + b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype) else: a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device) b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device) @@ -105,7 +108,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, bias = None if use_bias: - bias = torch.rand((N, ), device=device, dtype=out_dtype) + bias = torch.rand((N,), device=device, dtype=out_dtype) c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) diff --git a/tests/kernels/test_apply_repetition_penalties.py b/tests/kernels/test_apply_repetition_penalties.py index 90380b872d6c..a4619f5846b1 100644 --- a/tests/kernels/test_apply_repetition_penalties.py +++ b/tests/kernels/test_apply_repetition_penalties.py @@ -4,8 +4,10 @@ import torch from tests.kernels.utils import opcheck -from vllm._custom_ops import (apply_repetition_penalties_cuda, - apply_repetition_penalties_torch) +from vllm._custom_ops import ( + apply_repetition_penalties_cuda, + apply_repetition_penalties_torch, +) from vllm.platforms import current_platform NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025] @@ -21,8 +23,9 @@ @pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties( num_seqs: int, @@ -32,7 +35,7 @@ def test_apply_repetition_penalties( seed: int, ) -> None: """ - Test the apply_repetition_penalties custom op + Test the apply_repetition_penalties custom op against a reference implementation. 
""" current_platform.seed_everything(seed) @@ -46,39 +49,40 @@ def test_apply_repetition_penalties( output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool) # Mark some tokens as repeated in prompt and output - prompt_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) - output_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) + prompt_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) + output_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) for i in range(num_seqs): prompt_mask[i, prompt_indices[i]] = True output_mask[i, output_indices[i]] = True # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties_zero_seqs() -> None: """ @@ -104,22 +108,24 @@ def test_apply_repetition_penalties_zero_seqs() -> None: # No tokens to mark as repeated since num_seqs=0 # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index 2b745b84dae6..f305dffb6cd8 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -11,34 
+11,29 @@ if not current_platform.has_device_capability(100): pytest.skip( reason="Cutlass MLA Requires compute capability of 10 or above.", - allow_module_level=True) + allow_module_level=True, + ) def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) ): bs, num_heads, v_head_dim = out.shape head_dim = query.shape[2] for i in range(bs): # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) v = kv[:, :, :v_head_dim] q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) out[i] = o.view(num_heads, v_head_dim) return out @@ -49,10 +44,11 @@ def ref_mla( @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("block_size", [16, 64, 128]) -def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, - varlen: bool, block_size: int): +def test_cutlass_mla_decode( + dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int +): torch.set_default_dtype(dtype) - torch.set_default_device('cuda') + torch.set_default_device("cuda") torch.manual_seed(42) d = 576 @@ -61,12 +57,12 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, q_nope_dim = 128 q_pe_dim = 64 - scale = (q_nope_dim + q_pe_dim)**(-0.5) + scale = (q_nope_dim + q_pe_dim) ** (-0.5) if varlen: seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) seq_lens = seq_lens.clip(2).to(torch.int32) else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32) max_seq_len = seq_lens.max().item() block_num = (max_seq_len + block_size - 1) // block_size @@ -79,9 +75,7 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, # Amplify input values to ensure test coverage of edge cases where CUTLASS # kernel errors occur with split_k settings. 
q = torch.randn(bs, h_q, d) * 100 - block_table = torch.randint(0, - bs * block_num, (bs, block_num), - dtype=torch.int32) + block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) kv_cache = torch.randn(block_table.numel(), block_size, d) @@ -90,7 +84,8 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, out_ans = torch.zeros_like(out_ref) q_nope = q[:, :, :dv].clone() q_pe = q[:, :, dv:].clone() - ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, - block_table, scale) + ops.cutlass_mla_decode( + out_ans, q_nope, q_pe, kv_cache, seq_lens, block_table, scale + ) torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index e25556c89fb9..e67f04c07b3b 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -43,10 +43,9 @@ def test_flex_attention_vs_default_backend(monkeypatch): "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) + sampling_params = SamplingParams( + temperature=0.0, top_p=1.0, seed=seed, max_tokens=max_tokens + ) # Run with flex attention with monkeypatch.context() as m: @@ -76,8 +75,7 @@ def test_flex_attention_vs_default_backend(monkeypatch): output_default = llm_default.generate(prompts, sampling_params) # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): + for i, (flex_result, default_result) in enumerate(zip(output_flex, output_default)): prompt = prompts[i] flex_text = flex_result.outputs[0].text default_text = default_result.outputs[0].text @@ -85,7 +83,8 @@ def test_flex_attention_vs_default_backend(monkeypatch): assert flex_text == default_text, ( f"FlexAttention output doesn't match default for: {prompt!r}\n" f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + f"Default: {default_text!r}" + ) if __name__ == "__main__": diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index 803453a20d81..c79e6105e69f 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -13,13 +13,12 @@ NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] -def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, - scale: torch.Tensor) -> torch.Tensor: +def ref_impl( + silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor +) -> torch.Tensor: silu_and_mul_out = silu_and_mul.forward_native(x) out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) return out @@ -27,9 +26,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: out_shape = (x.shape[0], x.shape[1] // 2) - out = torch.empty(out_shape, - dtype=current_platform.fp8_dtype(), - device=x.device) + out = torch.empty(out_shape, dtype=current_platform.fp8_dtype(), device=x.device) torch.ops._C.silu_and_mul_quant(out, x, scale) return out @@ -57,7 +54,7 @@ def test_silu_and_mul( layer = SiluAndMul() # Make inputs - scale = (torch.randn((1), device=device, dtype=torch.float32)) + scale = torch.randn((1), 
device=device, dtype=torch.float32) x = torch.randn(num_tokens, hidden_size, dtype=dtype) ref_out = ref_impl(layer, x, scale) @@ -66,6 +63,7 @@ def test_silu_and_mul( assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype assert ref_out.shape == ops_out.shape - assert torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + assert torch.allclose( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py index 1c31cfb25e5a..4b0bbb992d2e 100644 --- a/tests/kernels/test_triton_flash_attention.py +++ b/tests/kernels/test_triton_flash_attention.py @@ -4,21 +4,24 @@ Run `pytest tests/kernels/test_triton_flash_attention.py`. """ + import pytest import torch -from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, - MetaData, - compute_alibi_tensor, - scale_fp8, - triton_attention_rocm) +from vllm.attention.ops.triton_flash_attention import ( + SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm, +) from vllm.platforms import current_platform class ReferenceAttention: - - def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, - input_metadata): + def __init__( + self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ): self.Z = Z self.HQ = HQ self.HK = HK @@ -30,21 +33,23 @@ def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, self.input_metadata = input_metadata def fwd(self, q, k, v): - scores = torch.einsum('bhqd,bhkd->bhqk', q, - k).float() * self.input_metadata.sm_scale + scores = ( + torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale + ) if self.input_metadata.causal: - mask = torch.tril(torch.ones(self.N_CTX_Q, - self.N_CTX_K, - device="cuda"), - diagonal=self.N_CTX_K - self.N_CTX_Q) + mask = torch.tril( + torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q, + ) scores[:, :, mask == 0] = float("-inf") if self.input_metadata.bias is not None: scores += self.input_metadata.bias if self.use_alibi: - scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, - self.N_CTX_Q, self.N_CTX_K) + scores += compute_alibi_tensor( + self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K + ) p = torch.softmax(scores, dim=-1) if self.input_metadata.causal: @@ -54,31 +59,38 @@ def fwd(self, q, k, v): # should be out of the softmax. 
nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v) # compare - if self.input_metadata.layout == 'bshd': + if self.input_metadata.layout == "bshd": ref_out = ref_out.transpose(1, 2).clone() return ref_out def fwd_fp8(self, q_quantized, k_quantized, v_quantized): q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( - self.dtype) + self.dtype + ) k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( - self.dtype) + self.dtype + ) v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( - self.dtype) + self.dtype + ) result = self.fwd(q, k, v) if self.input_metadata.o_scale is not None: result, _ = scale_fp8(result, self.input_metadata.o_scale) return result def fwd_fp8_kv(self, q, k_quantized, v_quantized): - k_descale, v_descale = (self.input_metadata.k_descale, - self.input_metadata.v_descale) - k_dequantized = (k_quantized.to(torch.float32) * - k_descale.to(torch.float32)).to(self.dtype) - v_dequantized = (v_quantized.to(torch.float32) * - v_descale.to(torch.float32)).to(self.dtype) + k_descale, v_descale = ( + self.input_metadata.k_descale, + self.input_metadata.v_descale, + ) + k_dequantized = ( + k_quantized.to(torch.float32) * k_descale.to(torch.float32) + ).to(self.dtype) + v_dequantized = ( + v_quantized.to(torch.float32) * v_descale.to(torch.float32) + ).to(self.dtype) return self.fwd(q, k_dequantized, v_dequantized) def varlen_fwd(self, q, k, v, is_mqa=False): @@ -86,29 +98,33 @@ def varlen_fwd(self, q, k, v, is_mqa=False): if is_mqa: # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so # the size aligns with Q. - k_ref = k.view(k.shape[0], k.shape[1], 1, - k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) - v_ref = v.view(v.shape[0], v.shape[1], 1, - v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) else: k_ref = k v_ref = v for i in range(0, self.input_metadata.num_contexts): - start_q, start_k = self.input_metadata.cu_seqlens_q[ - i], self.input_metadata.cu_seqlens_k[i] - end_q, end_k = self.input_metadata.cu_seqlens_q[ - i + 1], self.input_metadata.cu_seqlens_k[i + 1] + start_q, start_k = ( + self.input_metadata.cu_seqlens_q[i], + self.input_metadata.cu_seqlens_k[i], + ) + end_q, end_k = ( + self.input_metadata.cu_seqlens_q[i + 1], + self.input_metadata.cu_seqlens_k[i + 1], + ) k_curr = k_ref[start_k:end_k] v_curr = v_ref[start_k:end_k] if is_mqa: k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], - k_curr).float() - p = torch.softmax(scores * self.input_metadata.sm_scale, - dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr) return ref_out @@ -123,8 +139,7 @@ def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): # model. 
p_scale = None - o_scale = torch.rand(1, device="cuda", - requires_grad=False) if use_o_scale else None + o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale @@ -150,10 +165,10 @@ def input_helper( current_platform.seed_everything(0) # Initialize q, k, v - if layout == 'bhsd': + if layout == "bhsd": q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == 'bshd': + elif layout == "bshd": q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) @@ -161,69 +176,54 @@ def input_helper( # for n heads the set of slopes is the geometric sequence that starts # 2^(-8/n) alibi_slopes = torch.tensor( - [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + [2 ** (-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) + device="cuda", + ).repeat(Z, 1) else: alibi_slopes = None if use_bias: - bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), - dtype=dtype, - device="cuda", - requires_grad=False) + bias = torch.randn( + (1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False + ) else: bias = None - q = torch.randn(q_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - k = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - v = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) if is_fp8: - (q, k, v, q_descale, k_descale, v_descale, p_scale, - o_scale) = quantize_input(q, - k, - v, - use_o_scale=use_o_scale, - fp8_kv=fp8_kv) + (q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input( + q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv + ) else: q_descale = k_descale = v_descale = p_scale = o_scale = None - input_metadata = MetaData(sm_scale=D_HEAD**-0.5, - max_seqlens_q=N_CTX_Q, - max_seqlens_k=N_CTX_K, - layout=layout, - alibi_slopes=alibi_slopes, - alibi_batch=Z, - alibi_nheads=HQ, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=o_scale, - bias=bias, - seqlen_q=N_CTX_Q, - seqlen_k=N_CTX_K) + input_metadata = MetaData( + sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K, + ) return q, k, v, input_metadata -def varlen_input_helper(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - equal_seqlens=False): +def varlen_input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False +): current_platform.seed_everything(0) # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual @@ -231,66 +231,72 @@ def varlen_input_helper(Z, if not equal_seqlens: max_seqlens_q = N_CTX_Q // Z max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, - max_seqlens_q + 1, (Z, ), - dtype=torch.int32) - seqlens_k = torch.randint(1, - max_seqlens_k + 1, (Z, ), - dtype=torch.int32) + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) else: - seqlens_q = torch.full((Z, ), N_CTX_Q // Z) - seqlens_k = torch.full((Z, ), N_CTX_K // Z) + seqlens_q = torch.full((Z,), N_CTX_Q // Z) + seqlens_k = torch.full((Z,), N_CTX_K // Z) # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_q.cumsum(dim=0, dtype=torch.int32) - ]) - cu_seqlens_k = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_k.cumsum(dim=0, dtype=torch.int32) - ]) + cu_seqlens_q = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32), + ] + ) + cu_seqlens_k = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32), + ] + ) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() total_k = cu_seqlens_k[-1].item() - q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() + q = ( + torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 48, 12, 1, 1, 64), - (4, 4, 4, 128, 128, 65), - (16, 48, 48, 1, 1, 128), - (64, 48, 24, 3, 3, 128), - (4, 4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_alibi', [True, False]) -@pytest.mark.parametrize('layout', ['bshd']) -def test_op_fwd(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - use_alibi, - layout, - dtype=torch.float16): +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_alibi", [True, False]) +@pytest.mark.parametrize("layout", ["bshd"]) +def test_op_fwd( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16 +): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - dtype, layout, use_alibi, causal) + q, k, v, input_metadata = input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal + ) o = torch.empty_like(q) @@ -299,48 +305,50 @@ 
def test_op_fwd(Z, # Transpose here if layout is bshd so we have same reference code for all # layouts - if layout == 'bshd': + if layout == "bshd": q = q.transpose(1, 2).clone() k = k.transpose(1, 2).clone() v = v.transpose(1, 2).clone() # Replicate K and V if using MQA/GQA if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(k.shape[0], -1, k.shape[2], - k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(v.shape[0], -1, v.shape[2], - v.shape[3]) - - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - use_alibi, dtype, input_metadata) + k = ( + k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + ) + v = ( + v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + ) + + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -@pytest.mark.parametrize('use_o_scale', [True, False]) -@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), - reason="Triton FP8 requires CUDA 9.0 or higher") -def test_op_fwd_fp8(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - use_o_scale, - dtype=torch.float32): +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +@pytest.mark.parametrize("use_o_scale", [True, False]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher", +) +def test_op_fwd_fp8( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32 +): current_platform.seed_everything(0) # Disable grad to save memory it won't run into OOM on CI machine. 
@@ -358,95 +366,103 @@ def test_op_fwd_fp8(Z, causal=causal, layout=layout, is_fp8=True, - use_o_scale=use_o_scale) + use_o_scale=use_o_scale, + ) o = torch.empty_like(q_quantized) if use_o_scale else None - tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, - o, input_metadata) + tri_out, _ = triton_attention_rocm( + q_quantized, k_quantized, v_quantized, o, input_metadata + ) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) # compare - torch.testing.assert_close(ref_out.to(torch.float32), - tri_out.to(torch.float32), - atol=7e-2, - rtol=2e-1) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -def test_op_fwd_fp8_kv(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - dtype=torch.float32): + torch.testing.assert_close( + ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1 + ) + + +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +def test_op_fwd_fp8_kv( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32 +): current_platform.seed_everything(0) - q, k_quantized, v_quantized, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - causal=causal, - layout=layout, - is_fp8=True, - fp8_kv=True) + q, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True, + ) o = torch.empty_like(q) - tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, - input_metadata) + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_bias', [True]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_bias", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - layout='bhsd', - causal=causal, - use_bias=use_bias) + q, k, v, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, 
+ N_CTX_K, + D_HEAD, + dtype, + layout="bhsd", + causal=causal, + use_bias=use_bias, + ) o = torch.empty_like(q) # triton implementation tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) # compare @@ -454,47 +470,47 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), - (4, 48, 512, 64), - (16, 48, 512, 64), - (64, 48, 128, 128)]) -@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize( + "Z, H, N_CTX, D_HEAD", + [(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)], +) +@pytest.mark.parametrize("causal", [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, - D_HEAD, dtype) + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, - input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), - (4, 48, 12, 256, 64), - (4, 48, 4, 512, 64), - (4, 64, 16, 128, 128)]) -@pytest.mark.parametrize('causal', [False]) -def test_op_varlen_mqa_fwd(Z, - HQ, - HK, - N_CTX, - D_HEAD, - causal, - dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, - D_HEAD, dtype) +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX, D_HEAD", + [ + (2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128), + ], +) +@pytest.mark.parametrize("causal", [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype + ) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2e8febbdcf26..1c9636d5c6a2 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -16,11 +16,14 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms.interface import _Backend -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) +from vllm.utils 
import ( + STR_BACKEND_ENV_VAR, + STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad, +) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -39,7 +42,7 @@ class QKVInputs(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, query/key/values and their sequence lengths. @@ -49,7 +52,7 @@ class QKVInputs(NamedTuple): num_heads x head_size) attention inputs * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -59,7 +62,7 @@ class QKVInputs(NamedTuple): class QKVO(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, alongside unpacked known-correct attention output @@ -69,14 +72,14 @@ class QKVO(NamedTuple): num_heads x head_size) attention inputs * ideal_output: unpacked (batch_size x padded_seq_len x num_heads x head_size) known-correct attention output - ''' + """ qkv: QKVInputs ideal_output: torch.Tensor class PackedQKVInputs(NamedTuple): - ''' + """ Data structure for representing packed attention inputs Attributes: @@ -88,7 +91,7 @@ class PackedQKVInputs(NamedTuple): packed tensor * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -100,7 +103,7 @@ class PackedQKVInputs(NamedTuple): class PackedQKVO(NamedTuple): - ''' + """ Data structure for representing packed attention inputs, alongside packed known-correct attention output @@ -110,28 +113,28 @@ class PackedQKVO(NamedTuple): x head_size) attention inputs * ideal_output: packed (number_of_tokens x num_heads x head_size) known-correct attention output - ''' + """ packed_qkv: Optional[PackedQKVInputs] ideal_output: torch.Tensor class KVMemoryMap(NamedTuple): - ''' + """ Data structure for encapsulating KV cache memory mapping. 
Attributes: * block_tables: KV cache block tables * slot_mapping: mapping of sequence offset to physical address - ''' + """ block_tables: torch.Tensor slot_mapping: torch.Tensor class PhaseTestParameters(NamedTuple): - ''' + """ Data structure for encapsulating the test parameters for a given test "phase" (prefill or decode phase) and attention scenario (encoder, decoder-self, encoder/decoder-cross) @@ -143,7 +146,7 @@ class PhaseTestParameters(NamedTuple): output * kv_mmap: KV cache memory mapping, specific to this test phase & attention scenario - ''' + """ packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] @@ -153,41 +156,43 @@ def maybe_make_int_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D int torch.Tensor on `device` Returns: * If _list is not None: 1D int torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.int, device=device) + ) def maybe_make_long_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D long torch.Tensor on `device` Returns: * If _list is not None: 1D long torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.long, device=device) + ) def maybe_max(_list: Optional[list]) -> Optional[Number]: - ''' + """ Returns: * If _list is not None: max(_list) * None otherwise - ''' + """ return None if _list is None else max(_list) @@ -195,7 +200,7 @@ def make_causal_mask( q_max_seq_len: int, kv_max_seq_len: int, ) -> torch.Tensor: - ''' + """ Create a q_max_seq_len x kv_max_seq_len causal mask Arguments: @@ -206,19 +211,19 @@ def make_causal_mask( Returns: * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' + """ # Create a matrix where entry (i, j) is True if i >= j mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) + mask = mask.masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) return mask -def override_backend_env_variable(mpatch: pytest.MonkeyPatch, - backend_name: str) -> None: - ''' +def override_backend_env_variable( + mpatch: pytest.MonkeyPatch, backend_name: str +) -> None: + """ Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits. @@ -227,18 +232,20 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * mpatch: pytest monkeypatch instance * backend_name: attention backend name to force - ''' + """ mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) -def ref_masked_attention(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[list] = None, - kv_seq_lens: Optional[list] = None) -> torch.Tensor: - ''' +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[list] = None, + kv_seq_lens: Optional[list] = None, +) -> torch.Tensor: + """ "Golden" masked attention reference. 
Supports two types of masking: * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out @@ -260,14 +267,14 @@ def ref_masked_attention(query: torch.Tensor, Returns: * Attention result, batch_size x q_padded_seq_len x num_heads x head_size - ''' + """ assert q_seq_lens is not None assert kv_seq_lens is not None batch_size = query.shape[0] - assert (len(q_seq_lens) == batch_size) - assert (len(kv_seq_lens) == batch_size) + assert len(q_seq_lens) == batch_size + assert len(kv_seq_lens) == batch_size attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() @@ -303,7 +310,7 @@ def make_qkv( attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple[QKVInputs, QKVInputs, QKVInputs]: - ''' + """ Construct QKV test tensors for self- and cross-attention. Generates three query/key/value triplets: @@ -340,14 +347,12 @@ def make_qkv( * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) * Prefill QKVInputs structure (containing all but the last sequence offset) * Decode QKVInputs structure (containing all only the last sequence offset) - ''' + """ if force_max_len: q_seq_lens = [max_q_seq_len for _ in range(batch_size)] else: - q_seq_lens = [ - random.randint(2, max_q_seq_len) for _ in range(batch_size) - ] + q_seq_lens = [random.randint(2, max_q_seq_len) for _ in range(batch_size)] kv_seq_lens = None if force_kv_seq_lens is not None: kv_seq_lens = force_kv_seq_lens @@ -360,50 +365,44 @@ def make_qkv( if force_max_len: kv_seq_lens = [max_kv_seq_len] * batch_size else: - kv_seq_lens = [ - random.randint(2, max_kv_seq_len) for _ in range(batch_size) - ] - - query = torch.rand( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - key = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - value = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - decode_query = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + kv_seq_lens = [random.randint(2, max_kv_seq_len) for _ in range(batch_size)] + + query = torch.rand((batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + prefill_query = torch.zeros((batch_size, max_q_seq_len, num_heads, head_size)).to( + device + ) + prefill_key = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) + prefill_value = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) + + decode_query = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - decode_value = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, - kv_seq_lens)): + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): query[bdx, q_seq_len:, :, :] = 0 key[bdx, kv_seq_len:, :, :] = 0 value[bdx, kv_seq_len:, :, :] = 0 - prefill_query[bdx, - 0:(q_seq_len - 1), :, :] = query[bdx, - 
0:(q_seq_len - 1), :, :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :, :] = key[bdx, - 0:(kv_seq_len - 1), :, :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] - - decode_query[bdx, :, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :, :] - decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] - decode_value[bdx, :, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :, :] + prefill_query[bdx, 0 : (q_seq_len - 1), :, :] = query[ + bdx, 0 : (q_seq_len - 1), :, : + ] + prefill_key[bdx, 0 : (kv_seq_len - 1), :, :] = key[ + bdx, 0 : (kv_seq_len - 1), :, : + ] + prefill_value[bdx, 0 : (kv_seq_len - 1), :, :] = value[ + bdx, 0 : (kv_seq_len - 1), :, : + ] + + decode_query[bdx, :, :, :] = query[bdx, (q_seq_len - 1) : q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -417,25 +416,29 @@ def make_qkv( key, value, q_seq_lens, - kv_seq_lens), + kv_seq_lens, + ), QKVInputs( prefill_query, # Prefill subset of QKV sequences prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens), + prefill_kv_seq_lens, + ), QKVInputs( decode_query, # Decode subset of KV sequences decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens)) + decode_kv_seq_lens, + ), + ) def pack_tensor( - unpacked_tensor: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]: - ''' + unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str] +) -> tuple[torch.Tensor, list[int]]: + """ Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where number_of_tokens = sum(seq_lens) @@ -451,7 +454,7 @@ def pack_tensor( * packed_tensor: number_of_tokens x num_heads x head_size * start_loc_list: start idx of each batch elt in packed_tensor; [0] + list(itertools.accumulate(seq_lens)) - ''' + """ num_tok = sum(seq_lens) num_heads = unpacked_tensor.shape[-2] @@ -460,16 +463,15 @@ def pack_tensor( packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): - - packed_tensor[start_loc:( - start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + packed_tensor[start_loc : (start_loc + seq_len), :, :] = unpacked_tensor[ + bdx, :seq_len, :, : + ] return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, device: Union[torch.device, - str]) -> PackedQKVInputs: - ''' +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs: + """ Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x num_heads x head_size tensors. 
@@ -488,28 +490,30 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, * Packed (number_of_tokens x num_heads x head_size) QKV inputs derived from unpacked inputs - ''' + """ if qkv.query is None: packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(qkv.query, - qkv.q_seq_lens, - device=device) - packed_key, kv_start_loc_list = pack_tensor(qkv.key, - qkv.kv_seq_lens, - device=device) + packed_query, q_start_loc_list = pack_tensor( + qkv.query, qkv.q_seq_lens, device=device + ) + packed_key, kv_start_loc_list = pack_tensor(qkv.key, qkv.kv_seq_lens, device=device) packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) return PackedQKVInputs( - packed_query, packed_key, packed_value, q_start_loc_list, + packed_query, + packed_key, + packed_value, + q_start_loc_list, kv_start_loc_list, (None if q_start_loc_list is None else qkv.q_seq_lens), - qkv.kv_seq_lens) + qkv.kv_seq_lens, + ) def make_backend(backend_name: str) -> AttentionBackend: - ''' + """ Construct the backend instance determined by the backend_name string argument. @@ -527,17 +531,18 @@ def make_backend(backend_name: str) -> AttentionBackend: Returns: * Backend instance - ''' + """ if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend + return XFormersBackend() elif backend_name == STR_FLASH_ATTN_VAL: from vllm.attention.backends.flash_attn import FlashAttentionBackend + return FlashAttentionBackend() - raise AssertionError( - f"Unrecognized backend_name {backend_name} for unit test") + raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test") def _make_metadata_tensors( @@ -545,9 +550,17 @@ def _make_metadata_tensors( context_lens: Optional[list[int]], encoder_seq_lens: Optional[list[int]], device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], - torch.Tensor, torch.Tensor, Optional[int]]: - ''' +) -> tuple[ + torch.Tensor, + torch.Tensor, + Any, + Any, + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + Optional[int], +]: + """ Build scalar & tensor values required to build attention metadata structure. 
Arguments: @@ -567,48 +580,61 @@ def _make_metadata_tensors( * encoder_seq_lens_tensor: encoder seq_lens list, as tensor * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor - ''' + """ seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) context_lens_tensor = maybe_make_int_tensor(context_lens, device) max_context_len = maybe_max(context_lens) max_seq_len = maybe_max(seq_lens) encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) - max_encoder_seq_len = (None if encoder_seq_lens is None else - max(encoder_seq_lens)) + max_encoder_seq_len = None if encoder_seq_lens is None else max(encoder_seq_lens) seq_start_loc = None if seq_lens_tensor is not None: - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=seq_lens_tensor.device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=encoder_seq_lens_tensor.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, - max_encoder_seq_len) - - -def make_kv_cache(num_blocks: int, - num_heads: int, - head_size: int, - block_size: int, - device: Union[torch.device, str], - backend: str, - default_val: float = 0.0) -> torch.Tensor: - ''' + seq_start_loc = torch.zeros( + seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device, + ) + torch.cumsum( + seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:] + ) + + encoder_seq_start_loc = torch.zeros( + encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device, + ) + torch.cumsum( + encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:], + ) + + return ( + seq_lens_tensor, + context_lens_tensor, + max_context_len, + max_seq_len, + seq_start_loc, + encoder_seq_lens_tensor, + encoder_seq_start_loc, + max_encoder_seq_len, + ) + + +def make_kv_cache( + num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + backend: str, + default_val: float = 0.0, +) -> torch.Tensor: + """ Create a fake KV cache. Arguments: @@ -626,27 +652,29 @@ def make_kv_cache(num_blocks: int, * for backend 'XFORMERS' * kv_cache: 2 x num_blocks x block_size x num_heads x head_size * for backend 'FLASH_ATTN' - ''' - if backend == 'XFORMERS': - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) - elif backend == 'FLASH_ATTN': - kv_cache = torch.rand( - (2, num_blocks, block_size, num_heads, head_size)).to(device) + """ + if backend == "XFORMERS": + kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to( + device + ) + elif backend == "FLASH_ATTN": + kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to( + device + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." 
+ ) if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: - ''' + """ Compute the minimum number of blocks required to hold num_tokens tokens, given block_size - ''' + """ return (num_tokens + block_size) // block_size @@ -658,9 +686,12 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]): - ''' +def split_slot_mapping( + slot_mapping_list: torch.Tensor, + seq_lens: list[int], + device: Union[torch.device, str], +): + """ Split a slot mapping into valid prefill- and decode-phase slot mappings. Context: @@ -698,28 +729,32 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], reflecting all N prefill prompts * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting all N decoded tokens - ''' + """ prefill_slot_mapping = [] decode_slot_mapping = [] base_idx = 0 for seq_len in seq_lens: - prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + - seq_len - 1)]) + prefill_slot_mapping.extend( + slot_mapping_list[base_idx : (base_idx + seq_len - 1)] + ) decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len - return (maybe_make_long_tensor(prefill_slot_mapping, device), - maybe_make_long_tensor(decode_slot_mapping, device)) + return ( + maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device), + ) def make_block_tables_slot_mapping( - block_size: int, - seq_lens: list[int], - device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]: - ''' + block_size: int, + seq_lens: list[int], + device: Union[torch.device, str], + block_base_addr: int = 0, +) -> tuple[torch.Tensor, list[int], int]: + """ Construct fake block tables & slot mappings. 
For a sequence with num_tokens tokens the minimum number @@ -756,12 +791,11 @@ def make_block_tables_slot_mapping( * block_tables_tensor: block table for sequence * slot_mapping_list: slot mapping for sequence * max_block_idx: the highest block address within this block table - ''' + """ # Provision minimum number of KV cache blocks num_blocks_list = [ - _num_tokens_to_min_blocks(num_tokens, block_size) - for num_tokens in seq_lens + _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -774,11 +808,11 @@ def make_block_tables_slot_mapping( max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] - block_table = list( - range(block_base_idx, block_base_idx - num_blocks, -1)) + block_table = list(range(block_base_idx, block_base_idx - num_blocks, -1)) for idx in range(num_tokens): - mapping_value = ( - idx % block_size) + block_table[idx // block_size] * block_size + mapping_value = (idx % block_size) + block_table[ + idx // block_size + ] * block_size slot_mapping_list.append(mapping_value) block_base_idx -= num_blocks @@ -802,9 +836,9 @@ def make_test_metadata( decoder_test_params: Optional[PhaseTestParameters], device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, - cross_test_params: Optional[PhaseTestParameters] = None + cross_test_params: Optional[PhaseTestParameters] = None, ) -> AttentionMetadata: - ''' + """ Construct fake attention metadata for a given test phase (prefill-phase or decode-phase). @@ -841,13 +875,12 @@ def make_test_metadata( Return: * AttentionMetadata structure - ''' + """ # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None - kv_mmap = (None - if decoder_test_params is None else decoder_test_params.kv_mmap) + kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap # This function constructs metadata assuming no chunked prefill, # i.e. 
100% prefill tokens or 100% decode tokens @@ -860,10 +893,11 @@ def make_test_metadata( # seq_lens is None signals encoder-only # scenario, in which case num_prefills_or_decodes and # num_prefill_or_decode_tokens are unused - num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) + num_prefills_or_decodes = None if seq_lens is None else len(seq_lens) - num_prefill_or_decode_tokens = (None if seq_lens is None else ( - sum(seq_lens) if is_prompt else len(seq_lens))) + num_prefill_or_decode_tokens = ( + None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens)) + ) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -877,8 +911,9 @@ def make_test_metadata( # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - num_encoder_tokens = (None if encoder_seq_lens is None else - (sum(encoder_seq_lens))) + num_encoder_tokens = ( + None if encoder_seq_lens is None else (sum(encoder_seq_lens)) + ) if cross_test_params is None: cross_kv_mmap = None @@ -906,10 +941,9 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), @@ -930,10 +964,13 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) else: # not is_prompt # Decode-phase scenario @@ -955,10 +992,9 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, @@ -981,16 +1017,19 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) - - -def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor, - backend: str) -> None: - ''' + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) + + +def assert_actual_matches_ideal( + test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str +) -> None: + """ Assert that observed output matches the ideal output contained in the test parameters data structure. 
@@ -998,24 +1037,24 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * test_params: Test parameters including packed ideal output * output_under_test: actually observed output value - ''' + """ ideal_output = test_params.packed_qkvo.ideal_output - if backend == 'XFORMERS': - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == "XFORMERS": + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output) + ) - elif backend == 'FLASH_ATTN': + elif backend == "FLASH_ATTN": # For FlashAttention override the accuracy thresholds to non default # values since we notice a higher difference between the ideal and # actual output. - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output), - atol=0.01, - rtol=0.016) + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016 + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." + ) # Copied/modified from torch._refs.__init__.py @@ -1029,19 +1068,15 @@ def fp8_allclose( """ Reference implementation of torch.allclose """ - torch._refs._check_close_args(name="torch.allclose", - a=a, - b=b, - rtol=rtol, - atol=atol) + torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) return bool( torch.all( - torch.isclose(a.double(), - b.double(), - rtol=rtol, - atol=atol, - equal_nan=equal_nan)).item()) + torch.isclose( + a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan + ) + ).item() + ) # Marlin MoE test utils @@ -1054,7 +1089,8 @@ def stack_and_dev(tensors: list[torch.Tensor]): def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def torch_experts( @@ -1074,10 +1110,11 @@ def torch_experts( block_shape: Optional[list[int]] = None, apply_router_weights_on_input: bool = False, ) -> torch.Tensor: - assert (global_num_experts == -1 - or (global_num_experts == w1.shape[0] and expert_map is None) - or (expert_map is not None - and global_num_experts == expert_map.shape[0])) + assert ( + global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None and global_num_experts == expert_map.shape[0]) + ) M, K = a.shape topk = topk_ids.shape[1] @@ -1092,8 +1129,9 @@ def torch_experts( if a1_scale: assert not per_act_token_quant and block_shape is None - a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype, - per_act_token_quant, block_shape) + a, a_scale = moe_kernel_quantize_input( + a, a1_scale, quant_dtype, per_act_token_quant, block_shape + ) num_experts = w1.shape[0] @@ -1112,22 +1150,28 @@ def torch_experts( out[mask] = tmp2 @ w2[i].transpose(0, 1) elif block_shape is not None: # block quantized - assert (a_scale is not None and w1_scale is not None - and w2_scale is not None) - tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], - w1_scale[i], block_shape, - out.dtype) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) + tmp1 = native_w8a8_block_matmul( + a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype + ) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, 
block_shape + ) - out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, - w2_scale[i], block_shape, - out.dtype) + out[mask] = native_w8a8_block_matmul( + tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype + ) else: - assert (a_scale is not None and w1_scale is not None - and w2_scale is not None) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) scales = a_scale if a_scale.numel() == 1 else a_scale[mask] tmp1 = a[mask].to(f32) * scales @@ -1137,8 +1181,8 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1).to(out.dtype) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape + ) assert b_scale is not None tmp2 = tmp2.to(f32) * b_scale @@ -1148,21 +1192,27 @@ def torch_experts( if apply_router_weights_on_input: return out else: - return (out.view(M, -1, w2.shape[1]).to(f32) * - topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) + return ( + (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1)) + .sum(dim=1) + .to(out.dtype) + ) -def torch_moe(a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) - return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - expert_map) + return torch_experts( + a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map + ) def torch_moe_single(a, w, score, topk): @@ -1181,41 +1231,49 @@ def torch_moe_single(a, w, score, topk): # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. 
-def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, - torch._library.custom_ops.CustomOpDef], - args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, - *, - test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, - raise_exception: bool = True, - cond: bool = True) -> dict[str, str]: - with unittest.mock.patch('torch.allclose', new=fp8_allclose): - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} +def opcheck( + op: Union[ + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + torch._library.custom_ops.CustomOpDef, + ], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, + raise_exception: bool = True, + cond: bool = True, +) -> dict[str, str]: + with unittest.mock.patch("torch.allclose", new=fp8_allclose): + return ( + torch.library.opcheck( + op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception + ) + if cond + else {} + ) # For testing quantized linear kernels def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor): return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) -def baseline_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - +def baseline_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting # in numpy simply stretches dimensions with an extent of 1 to match the # the target shape by repeating the data along that dimension (broadcasting) @@ -1234,16 +1292,19 @@ def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: assert s % t.shape[i] == 0 - t = t.unsqueeze(i + 1)\ - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ - .flatten(i, i + 1) + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) + .flatten(i, i + 1) + ) return t scale_a = group_broadcast(scale_a, a.shape) scale_b = group_broadcast(scale_b, b.shape) - output = torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))).to(out_dtype) + output = torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ).to(out_dtype) if bias is not None: output = output + bias diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py index 9f2229cc41df..1d24851bb782 100644 --- a/tests/kv_transfer/test_disagg.py +++ b/tests/kv_transfer/test_disagg.py @@ -19,8 +19,11 @@ def setup_servers(): pytest.skip("Skipping test: fewer than 2 GPUs available") # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", - shell=True).decode().strip() + VLLM_HOST_IP = ( + subprocess.check_output("hostname -I | awk '{print $1}'", shell=True) + .decode() + .strip() + ) os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP # Start prefill instance @@ -37,7 +40,7 @@ def setup_servers(): "--max-model-len", 
"1000", "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",' '"kv_rank":0,"kv_parallel_size":2}', ] prefill_env = os.environ.copy() @@ -58,7 +61,7 @@ def setup_servers(): "--max-model-len", "1000", "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",' '"kv_rank":1,"kv_parallel_size":2}', ] decode_env = os.environ.copy() @@ -98,23 +101,27 @@ def wait_for_server(port, timeout=240): @pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) def test_disaggregated_prefilling(prompt): # Send to prefill - response = requests.post("http://localhost:8100/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - }) + response = requests.post( + "http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0, + }, + ) assert response.status_code == 200 # Send to decode - response = requests.post("http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - }) + response = requests.post( + "http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0, + }, + ) assert response.status_code == 200 diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 352ab63552de..ff96527318b8 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -8,8 +8,7 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe # TODO: the test depends on a lot of fields in the current implementation. 
@@ -17,7 +16,6 @@ def test_run(my_rank, buffer, device): - # buffer should be empty in the beginning if my_rank == 0: assert buffer.buffer_size == 0 @@ -27,7 +25,7 @@ def test_run(my_rank, buffer, device): # insert tokens = torch.tensor([1, 2, 3]).to(device) - roi = (tokens > 0) + roi = tokens > 0 if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(device) value = 3.0 * torch.ones([5, 6]).to(device) @@ -55,7 +53,6 @@ def test_run(my_rank, buffer, device): def stress_test(my_rank, buf, device): - torch.distributed.barrier() torch.manual_seed(100) @@ -66,7 +63,8 @@ def stress_test(my_rank, buf, device): torch.rand(100).to(device), # key torch.rand(100).to(device), # value torch.rand(100).to(device), # hidden - ) for i in tqdm(range(200)) + ) + for i in tqdm(range(200)) ] random.seed(my_rank) @@ -115,12 +113,11 @@ def stress_test(my_rank, buf, device): if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) @@ -128,8 +125,8 @@ def stress_test(my_rank, buf, device): print(f"initialized! My rank is {my_rank}") config = KVTransferConfig( - kv_connector='PyNcclConnector', - kv_buffer_device='cuda', + kv_connector="PyNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test @@ -160,4 +157,4 @@ def stress_test(my_rank, buf, device): buffer.close() data_pipe.close() cpu_pipe.close() - print('Done') + print("Done") diff --git a/tests/kv_transfer/test_module.py b/tests/kv_transfer/test_module.py index 7a04174870da..b9a28e4bceb7 100644 --- a/tests/kv_transfer/test_module.py +++ b/tests/kv_transfer/test_module.py @@ -9,21 +9,19 @@ def run_python_script(script_name, timeout): - script_name = f'kv_transfer/{script_name}' + script_name = f"kv_transfer/{script_name}" try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "0"}, # Set the RANK environment variable for process 0 + env={"RANK": "0"}, # Set the RANK environment variable for process 0 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) process1 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "1"}, # Set the RANK environment variable for process 1 + env={"RANK": "1"}, # Set the RANK environment variable for process 1 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) @@ -34,11 +32,9 @@ def run_python_script(script_name, timeout): # Check the return status of both processes if process0.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=0, {process0.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}") if process1.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=1, {process1.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}") except subprocess.TimeoutExpired: # If either process times out, terminate both and fail the test @@ -53,15 +49,14 @@ def run_python_script(script_name, timeout): @pytest.mark.parametrize( "script_name,timeout", [ - ("test_lookup_buffer.py", - 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120) # First test case with a 120-second timeout - ]) + ("test_lookup_buffer.py", 60), # Second 
test case with a 60-second timeout + ("test_send_recv.py", 120), # First test case with a 120-second timeout + ], +) def test_run_python_script(script_name, timeout): # Check the number of GPUs if torch.cuda.device_count() < 2: - pytest.skip( - f"Skipping test {script_name} because <2 GPUs are available") + pytest.skip(f"Skipping test {script_name} because <2 GPUs are available") # Run the test if there are at least 2 GPUs run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 32116608a217..16ae4ad2ee9f 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -15,7 +15,7 @@ def test_run(my_rank, pipe): print(f"rank {my_rank} test_run starts....") # test run x = torch.tensor([1]).to(pipe.device) - y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device) if my_rank == 0: pipe.send_tensor(x) print(f"rank {my_rank} sent tensor x") @@ -53,9 +53,8 @@ def stress_test(my_rank, pipe): for i in tqdm(range(500)): mean = torch.rand(1).item() * 100 std = torch.rand(1).item() * 100 - size = torch.randint(900, 1000, (2, )) - x = torch.normal(mean * 1.0, std * 1.0, - size=size.tolist()).to(pipe.device) + size = torch.randint(900, 1000, (2,)) + x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) # 5% probability of sending a None if torch.rand(1).item() < 0.05: @@ -96,20 +95,16 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() for i in tqdm(range(500)): - tensors = [] if my_rank == 0: # create tensor - tensors = [ - torch.rand(nelement).to(pipe.device) for _ in range(ntensor) - ] + tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] torch.distributed.barrier() if my_rank == 0: - t = torch.tensor([time.time()], - dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -121,24 +116,23 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() - print('Latency test passed.') - print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + print("Latency test passed.") + print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms") if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) config = KVTransferConfig( - kv_connector='PyNcclConnector', - kv_buffer_device='cuda', + kv_connector="PyNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 881d5efa6919..7b836f765403 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,12 +12,16 @@ import vllm from vllm.config import LoRAConfig -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + 
MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -51,11 +55,13 @@ def dist_init(): if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" - init_distributed_environment(world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend=backend) + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend, + ) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -70,10 +76,9 @@ def dist_init_torch_only(): backend = "gloo" temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group(world_size=1, - rank=0, - init_method=f"file://{temp_file}", - backend=backend) + torch.distributed.init_process_group( + world_size=1, rank=0, init_method=f"file://{temp_file}", backend=backend + ) class DummyLoRAModel(nn.Sequential, SupportsLoRA): @@ -83,25 +88,31 @@ class DummyLoRAModel(nn.Sequential, SupportsLoRA): @pytest.fixture def dummy_model() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("output", ColumnParallelLinear(50, 10)), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + ("sampler", Sampler()), + ] + ) + ) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} return model @@ -110,25 +121,31 @@ def dummy_model() -> nn.Module: @pytest.fixture def dummy_model_gate_up() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + 
("sampler", Sampler()), + ] + ) + ) model.config = MagicMock() model.packed_modules_mapping = { "gate_up_proj": [ @@ -232,8 +249,7 @@ def llama_2_7b_engine_extra_embeddings(): get_model_old = get_model def get_model_patched(**kwargs): - kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, - max_lora_rank=8) + kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, max_lora_rank=8) return get_model_old(**kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): @@ -245,8 +261,9 @@ def get_model_patched(**kwargs): @pytest.fixture def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): - yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. - model_runner.model) + yield ( + llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.model_runner.model + ) @pytest.fixture diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index cc8160b2860d..a7d95ef19ccf 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -8,7 +8,8 @@ import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -27,14 +28,10 @@ def get_lora_requests(lora_path) -> list[LoRARequest]: return lora_requests -async def requests_processing_time(llm, - lora_requests: list[LoRARequest]) -> float: - - sampling_params = SamplingParams(n=1, - temperature=0.0, - top_p=1.0, - ignore_eos=True, - max_tokens=1) +async def requests_processing_time(llm, lora_requests: list[LoRARequest]) -> float: + sampling_params = SamplingParams( + n=1, temperature=0.0, top_p=1.0, ignore_eos=True, max_tokens=1 + ) generators = [] start = time.perf_counter() @@ -42,11 +39,11 @@ async def requests_processing_time(llm, for lora_request in lora_requests: lora_int_id = lora_request.lora_int_id generator = llm.generate( - prompt=TextPrompt(prompt=f"hello {lora_int_id}", - multi_modal_data=None), # type: ignore + prompt=TextPrompt(prompt=f"hello {lora_int_id}", multi_modal_data=None), # type: ignore sampling_params=sampling_params, lora_request=lora_request, - request_id=f"test{lora_int_id}") + request_id=f"test{lora_int_id}", + ) generators.append(generator) all_gens = merge_async_iterators(*generators) @@ -59,13 +56,13 @@ async def requests_processing_time(llm, @pytest.mark.asyncio async def test_add_lora(chatglm3_lora_files): - """ + """ The add_lora function is used to pre-load some LoRA adapters into the engine in anticipation of future requests using these adapters. To test this functionality, we use the async engine to process some requests - We do it twice, once with add_lora() pre-loading and once without. - We measure the request processing time in both cases and expect the time + We measure the request processing time in both cases and expect the time to be lesser in the case with add_lora() calls. 
""" lora_requests: list[LoRARequest] = get_lora_requests(chatglm3_lora_files) @@ -79,18 +76,18 @@ async def test_add_lora(chatglm3_lora_files): max_loras=max_loras, max_lora_rank=LORA_RANK, max_model_len=128, - gpu_memory_utilization=0.8, #avoid OOM + gpu_memory_utilization=0.8, # avoid OOM trust_remote_code=True, - enforce_eager=True) + enforce_eager=True, + ) # split lora_requests into 3 parts part_size = len(lora_requests) // 3 dummy_run_requests = lora_requests[:part_size] - warmup_run_requests = lora_requests[part_size:part_size * 2] - cold_run_requests = lora_requests[part_size * 2:] + warmup_run_requests = lora_requests[part_size : part_size * 2] + cold_run_requests = lora_requests[part_size * 2 :] async with build_async_engine_client_from_engine_args(engine_args) as llm: - # Dummy run - So any 1-time functionality like triton kernel compilation # is complete here. await requests_processing_time(llm, dummy_run_requests) @@ -104,18 +101,16 @@ async def test_add_lora(chatglm3_lora_files): else: # No way to check V0 engine results as the calls just return None. pass - time_with_add_lora = await requests_processing_time( - llm, warmup_run_requests) + time_with_add_lora = await requests_processing_time(llm, warmup_run_requests) # Run without any warmup - time_cold_start = await requests_processing_time( - llm, cold_run_requests) + time_cold_start = await requests_processing_time(llm, cold_run_requests) - print(f"time hot-start {time_with_add_lora} vs " - f"time cold-start {time_cold_start} ") + print(f"time hot-start {time_with_add_lora} vs time cold-start {time_cold_start} ") assert time_with_add_lora < time_cold_start, ( f"time_with_add_lora={time_with_add_lora}, " f"time_cold_start={time_cold_start}" "The engine request processing time with LoRA pre-loading " - "must be less than the version that does on-demand LoRA loading.") + "must be less than the version that does on-demand LoRA loading." + ) diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 774ebb9db210..3eb8c81f2261 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -16,12 +16,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query="What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 ), PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + query="Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 ), ] print(prompts) @@ -29,8 +27,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -42,12 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + ) expected_lora_output = [ "SELECT count(*) FROM singer", @@ -64,31 +64,36 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): +def test_baichuan_tensor_parallel_equality( + baichuan_lora_files, num_gpus_available, fully_sharded +): if num_gpus_available < 4: pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) + llm_tp1 = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) del llm_tp1 cleanup_dist_env_and_memory() - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) + llm_tp2 = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) del llm_tp2 @@ -96,14 +101,16 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files, assert output_tp1 == output_tp2 - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) + llm_tp4 = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) del llm_tp4 diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 5481b413b8f5..8495d8e8c168 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -21,20 +21,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query="What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 ), PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + query="Show name, country, age for all singers ordered by age from the oldest to the youngest." 
# noqa: E501 ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -47,13 +45,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @create_new_process_for_each_test() def test_chatglm3_lora(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -66,15 +66,17 @@ def test_chatglm3_lora(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -87,15 +89,17 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py index f615ceda76b5..1a5b9ba3641d 100644 --- a/tests/lora/test_default_mm_loras.py +++ b/tests/lora/test_default_mm_loras.py @@ -32,15 +32,12 @@ "max_lora_rank": 320, "max_model_len": 12800, "gpu_memory_utilization": 0.8, - "limit_mm_per_prompt": { - "audio": 1 - }, + "limit_mm_per_prompt": {"audio": 1}, "enforce_eager": True, } -def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, - **kwargs): +def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, **kwargs): inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])] # Apply any additional kwargs as overrides to the base kwargs @@ -53,11 +50,11 @@ def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, max_tokens=128, audios=audios, lora_request=lora_request, - ) for prompts, audios in inputs + ) + for prompts, audios in inputs ] - assert vllm_outputs_with_default_lora[-1][-1][-1].endswith( - expected_suffix) + assert 
vllm_outputs_with_default_lora[-1][-1][-1].endswith(expected_suffix) def test_active_default_mm_lora( diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc2..9b7c78aeefda 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -15,29 +15,42 @@ from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA, +) + # yapf conflicts with isort for this block # yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LogitsProcessorWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers import ( + BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + LoRAMapping, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA, +) + # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) + ParallelLMHead, + VocabParallelEmbedding, + get_masked_input_and_mask, +) from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -51,11 +64,14 @@ pytestmark = pytest.mark.skipif( not (current_platform.is_cuda_alike() or current_platform.is_cpu()), - reason="Backend not supported") + reason="Backend not supported", +) -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) # prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -68,8 +84,8 @@ @pytest.fixture(autouse=True) def clean_cache_reset_device(reset_default_device): # Release any memory we might be holding on to. CI runs OOMs otherwise. - from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT + _LORA_B_PTR_DICT.clear() _LORA_A_PTR_DICT.clear() @@ -79,13 +95,14 @@ def clean_cache_reset_device(reset_default_device): @pytest.fixture(autouse=True) def skip_cuda_with_stage_false(request): """ - On cuda-like platforms, we use the same kernels for prefill and decode + On cuda-like platforms, we use the same kernels for prefill and decode stage, and 'stage' is generally ignored, so we only need to test once. 
""" if current_platform.is_cuda_alike(): try: if hasattr(request.node, "callspec") and hasattr( - request.node.callspec, "params"): + request.node.callspec, "params" + ): params = request.node.callspec.params if "stage" in params and params["stage"] is False: pytest.skip("Skip test when stage=False") @@ -94,9 +111,9 @@ def skip_cuda_with_stage_false(request): yield -def get_random_id_to_index(num_loras: int, - num_slots: int, - log: bool = True) -> list[Optional[int]]: +def get_random_id_to_index( + num_loras: int, num_slots: int, log: bool = True +) -> list[Optional[int]]: """Creates a random lora_id_to_index mapping. Args: @@ -109,7 +126,8 @@ def get_random_id_to_index(num_loras: int, if num_loras > num_slots: raise ValueError( f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " - "num_loras must be less than or equal to num_slots.") + "num_loras must be less than or equal to num_slots." + ) slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() @@ -158,19 +176,18 @@ def populate_loras( subloras: list[LoRALayerWeights] = [] sublora_len = layer_weights.shape[0] // repeats for i in range(repeats): - sublora = DummyLoRAManager( - layer_weights.device).init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] + sublora = DummyLoRAManager(layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[ + :, (sublora_len * i) : (sublora_len * (i + 1)) + ] sublora.optimize() subloras.append(sublora) - lora = PackedLoRALayerWeights.pack( - subloras) if repeats > 1 else subloras[0] + lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0] layer.set_lora( slot_idx, @@ -191,7 +208,7 @@ def create_random_inputs( input_size: tuple[int, ...], input_range: tuple[float, float], input_type: torch.dtype = torch.int, - device: torch.device = "cuda" + device: torch.device = "cuda", ) -> tuple[list[torch.Tensor], list[int], list[int]]: """Creates random inputs. 
@@ -213,14 +230,15 @@ def create_random_inputs( for _ in range(num_inputs): if input_type == torch.int: inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device=device)) + torch.randint( + low=int(low), high=int(high), size=input_size, device=device + ) + ) else: inputs.append( - torch.rand(size=input_size, dtype=input_type, device=device) * - high + low) + torch.rand(size=input_size, dtype=input_type, device=device) * high + + low + ) lora_id = random.choice(active_lora_ids) index_mapping += [lora_id] * input_size[0] @@ -258,9 +276,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -286,15 +304,18 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(inputs)) @@ -306,15 +327,12 @@ def create_random_embedding_layer(): input_, lora.lora_a, ) - result += (after_a @ lora.lora_b) + result += after_a @ lora.lora_b expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -324,24 +342,24 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -351,9 +369,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) 
-def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size, stage) -> None: - +def test_embeddings_with_new_embeddings( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -361,9 +379,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -373,12 +391,12 @@ def create_random_embedding_layer(): expanded_embedding = VocabParallelEmbedding( vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=vocab_size) + org_num_embeddings=vocab_size, + ) expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place - lora_embedding = VocabParallelEmbeddingWithLoRA( - deepcopy(expanded_embedding)) + lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding)) lora_embedding.create_lora_weights(max_loras, lora_config) return expanded_embedding, lora_embedding @@ -392,7 +410,8 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size) + ), generate_embeddings_tensor=256, ) @@ -410,52 +429,53 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range # to guarantee that their behavior is tested. 
- for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): embedding_id = lora_id - 1 input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) original_input_[-1] = vocab_size - input_[-2] = vocab_size + ( - (embedding_id + 1) * embeddings_tensor_len - 1) + input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - expanded_embedding.weight[vocab_size:vocab_size + - (embeddings_tensor_len * - max_loras)] = torch.cat(embeddings_tensors) + expanded_embedding.weight[ + vocab_size : vocab_size + (embeddings_tensor_len * max_loras) + ] = torch.cat(embeddings_tensors) lora_result = lora_embedding(torch.cat(original_inputs)) expected_results: list[torch.Tensor] = [] - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): lora = lora_dict[lora_id] result = expanded_embedding(input_) after_a = F.embedding( original_input_, lora.lora_a, ) - result += (after_a @ lora.lora_b) + result += after_a @ lora.lora_b expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -465,24 +485,24 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) + device=device, + ) original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -490,9 +510,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) -def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, - stage) -> None: - +def test_lm_head_logits_processor( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -500,22 +520,25 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, 
lora_dtype=torch.float16 + ) def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, - params_dtype=torch.float16) + linear = ParallelLMHead( + vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16, + ) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size + ) lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, linear.weight.device, - None) + logits_processor, 1024, linear.weight.dtype, linear.weight.device, None + ) lora_logits_processor.create_lora_weights(max_loras, lora_config) return linear, logits_processor, lora_logits_processor @@ -542,10 +565,9 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -556,25 +578,24 @@ def _pretest(): input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=linear, - embedding_bias=None) + hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None + ) original_lm_head = deepcopy(linear) - linear.weight[logits_processor. - org_vocab_size:logits_processor.org_vocab_size + - embeddings_tensor_len] = embeddings_tensor + linear.weight[ + logits_processor.org_vocab_size : logits_processor.org_vocab_size + + embeddings_tensor_len + ] = embeddings_tensor - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) + logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = logits_processor._get_logits(hidden_states=input_, - lm_head=linear, - embedding_bias=None) - result[:, vocab_size + embeddings_tensor_len:] = float("-inf") + result = logits_processor._get_logits( + hidden_states=input_, lm_head=linear, embedding_bias=None + ) + result[:, vocab_size + embeddings_tensor_len :] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -591,10 +612,9 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -606,17 +626,16 @@ def _pretest(): lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None)[:, :vocab_size] + embedding_bias=None, + )[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None) + embedding_bias=None, + ) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) 
@torch.inference_mode() @@ -624,9 +643,7 @@ def _pretest(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: - +def test_linear_replicated(dist_init, num_loras, device, stage, bias_enabled) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,23 +651,25 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + bias_enabled=bias_enabled, + ) def create_random_linear_replicated_layer(): - - linear = ReplicatedLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) if bias_enabled: assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices else: @@ -676,10 +695,9 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -699,10 +717,7 @@ def create_random_linear_replicated_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -715,22 +730,19 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -740,9 +752,9 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_parallel(dist_init, num_loras, 
orientation, fully_shard, - device, stage, bias_enabled) -> None: - +def test_linear_parallel( + dist_init, num_loras, orientation, fully_shard, device, stage, bias_enabled +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,33 +762,42 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + bias_enabled=bias_enabled, + ) def create_random_linear_parallel_layer(): if orientation == "row": - linear = RowParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = RowParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard - else RowParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + RowParallelLinearWithLoRA(linear) + if not fully_shard + else RowParallelLinearWithShardedLoRA(linear) + ) else: - linear = ColumnParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ColumnParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (ColumnParallelLinearWithLoRA(linear) - if not fully_shard else - ColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + ColumnParallelLinearWithLoRA(linear) + if not fully_shard + else ColumnParallelLinearWithShardedLoRA(linear) + ) lora_linear.create_lora_weights(max_loras, lora_config) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) if bias_enabled: assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices else: @@ -802,10 +823,9 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -825,10 +845,7 @@ def create_random_linear_parallel_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -841,22 +858,19 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, 
lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -866,9 +880,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: - +def test_column_parallel_packed( + dist_init, num_loras, repeats, fully_shard, device, stage, bias_enabled +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,41 +890,45 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + bias_enabled=bias_enabled, + ) def create_column_parallel_packed_layer(): if repeats == 2: - linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False, - params_dtype=torch.float16) + linear = MergedColumnParallelLinear( + 4096, [4096] * repeats, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedColumnParallelLinearWithLoRA(linear) - if not fully_shard else - MergedColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard + else MergedColumnParallelLinearWithShardedLoRA(linear) + ) elif repeats == 3: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLoRA(linear) - if not fully_shard else - MergedQKVParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedQKVParallelLinearWithLoRA(linear) + if not fully_shard + else MergedQKVParallelLinearWithShardedLoRA(linear) + ) else: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLoRA( - linear - ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear) + lora_linear = ( + QKVParallelLinearWithLoRA(linear) + if not fully_shard + else QKVParallelLinearWithShardedLoRA(linear) + ) @dataclass class FakeConfig: @@ -919,11 +937,15 @@ class FakeConfig: num_attention_heads = 32 n_slices = repeats - lora_linear.create_lora_weights(max_loras, - lora_config, - model_config=FakeConfig()) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == n_slices) + lora_linear.create_lora_weights( + max_loras, lora_config, model_config=FakeConfig() + ) + assert ( + lora_linear.n_slices + == 
len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == n_slices + ) if bias_enabled: assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices else: @@ -951,10 +973,9 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, @@ -971,17 +992,14 @@ class FakeConfig: result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) + result[ + :, sublora.lora_b.shape[1] * i : sublora.lora_b.shape[1] * (i + 1) + ] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) for slot_idx in range(max_loras): lora_linear.reset_lora(slot_idx) @@ -992,10 +1010,9 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, @@ -1009,15 +1026,13 @@ class FakeConfig: expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize( - "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))) + "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)) +) def test_vocab_parallel_embedding_indices(tp_size, seed): random.seed(seed) vocab_size = random.randint(4000, 64000) @@ -1035,20 +1050,24 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): token_ids: list[int] = [] for tp_rank in range(tp_size): - with patch( + with ( + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank - ), patch( + return_value=tp_rank, + ), + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size): + return_value=tp_size, + ), + ): vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size, 1, org_num_embeddings=org_vocab_size + ) vocab_size_padded = vocab_embedding.num_embeddings_padded shard_indices = vocab_embedding.shard_indices # Assert that the ranges are contiguous assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert (shard_indices.added_vocab_start_index == - last_added_vocab_end_index) + assert shard_indices.added_vocab_start_index == last_added_vocab_end_index # Ensure that we are not exceeding the vocab size computed_vocab_size += shard_indices.num_elements_padded @@ -1057,22 +1076,39 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): # Ensure that the ranges are not overlapping 
all_org_tokens.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) all_added_tokens.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) token_ids.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_org_elements_padded - - shard_indices.num_org_elements)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) + token_ids.extend( + [-1] + * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements) + ) token_ids.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_added_elements_padded - - shard_indices.num_added_elements)) + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) + token_ids.extend( + [-1] + * ( + shard_indices.num_added_elements_padded + - shard_indices.num_added_elements + ) + ) last_org_vocab_end_index = shard_indices.org_vocab_end_index last_added_vocab_end_index = shard_indices.added_vocab_end_index @@ -1100,130 +1136,165 @@ def test_get_masked_input_and_mask(): x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) # base tp 1 case, no padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=0) + modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0, + ) assert torch.equal(x, modified_x) # tp 2 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]) + ) # tp 4 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0, + ) + modified_x_rank_1, _ = 
get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=0) + num_org_vocab_padding=0, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]) + ) # base tp 1 case, with padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) + modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]) + ) # tp 2 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]) + ) # tp 4 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + 
num_org_vocab_padding=2, + ) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=2) + num_org_vocab_padding=2, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]) + ) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index bebf44b6dfd7..d7824178626f 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -27,27 +27,28 @@ " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501 ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - tensorizer_config_dict: Union[dict, None] = None) -> list[str]: +def do_sample( + llm: vllm.LLM, + lora_path: str, + lora_id: int, + tensorizer_config_dict: Union[dict, None] = None, +) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? 
[/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501 ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=256, - skip_special_tokens=False, - stop=["[/assistant]"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"] + ) if tensorizer_config_dict is not None: outputs = llm.generate( @@ -57,14 +58,19 @@ def do_sample(llm: vllm.LLM, str(lora_id), lora_id, lora_path, - tensorizer_config_dict=tensorizer_config_dict) - if lora_id else None) + tensorizer_config_dict=tensorizer_config_dict, + ) + if lora_id + else None, + ) else: outputs = llm.generate( prompts, sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + if lora_id + else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -75,53 +81,72 @@ def do_sample(llm: vllm.LLM, return generated_texts -def generate_and_test(llm, - sql_lora_files, - tensorizer_config_dict: Union[dict, None] = None): +def generate_and_test( + llm, sql_lora_files, tensorizer_config_dict: Union[dict, None] = None +): print("lora adapter created") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0, + ) + == EXPECTED_NO_LORA_OUTPUT + ) print("lora 1") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=1, + ) + == EXPECTED_LORA_OUTPUT + ) print("no lora") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=0, + ) + == EXPECTED_NO_LORA_OUTPUT + ) print("lora 2") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=2) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=2, + ) + == EXPECTED_LORA_OUTPUT + ) print("removing lora") @create_new_process_for_each_test() def test_llama_lora(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, max_loras=4, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, enable_lora=True, @@ -136,7 +161,6 @@ def test_llama_lora_tp4(sql_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, enable_lora=True, @@ -151,9 +175,9 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @multi_gpu_test(num_gpus=2) @create_new_process_for_each_test() -def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, - sql_lora_huggingface_id): - +def test_tp2_serialize_and_deserialize_lora( + tmp_path, sql_lora_files, sql_lora_huggingface_id +): # Run the tensorizing of the LoRA adapter and the model in a subprocess # to guarantee cleanup @@ -164,17 +188,28 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, lora_path = sql_lora_huggingface_id suffix = "test" try: - result = subprocess.run([ - sys.executable, - f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", - MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", - str(tp_size), "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix, "--serialization-kwargs", - '{"limit_cpu_concurrency": 4}' - ], - check=True, - capture_output=True, - text=True) + result = subprocess.run( + [ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", + "--model", + MODEL_PATH, + "--lora-path", + lora_path, + "--tensor-parallel-size", + str(tp_size), + "serialize", + "--serialized-directory", + str(tmp_path), + "--suffix", + suffix, + "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}', + ], + 
check=True, + capture_output=True, + text=True, + ) except subprocess.CalledProcessError as e: print("Tensorizing failed.") print("STDOUT:\n", e.stdout) @@ -186,25 +221,37 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, model_uri = tmp_path / "vllm" / model_ref / suffix / model_name tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) - loaded_vllm_model = LLM(model=model_ref, - load_format="tensorizer", - enable_lora=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config, - max_num_seqs=13, - tensor_parallel_size=2, - max_loras=2) + loaded_vllm_model = LLM( + model=model_ref, + load_format="tensorizer", + enable_lora=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + max_num_seqs=13, + tensor_parallel_size=2, + max_loras=2, + ) tc_as_dict = tensorizer_config.to_serializable() print("lora adapter created") - assert do_sample(loaded_vllm_model, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert ( + do_sample( + loaded_vllm_model, + sql_lora_files, + tensorizer_config_dict=tc_as_dict, + lora_id=0, + ) + == EXPECTED_NO_LORA_OUTPUT + ) print("lora 1") - assert do_sample(loaded_vllm_model, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + loaded_vllm_model, + sql_lora_files, + tensorizer_config_dict=tc_as_dict, + lora_id=1, + ) + == EXPECTED_LORA_OUTPUT + ) diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py index 01bc102bd112..ed23d8278488 100644 --- a/tests/lora/test_lora_allowed_token_ids.py +++ b/tests/lora/test_lora_allowed_token_ids.py @@ -3,16 +3,16 @@ import pytest -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - VllmConfig) +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, VllmConfig from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.engine.processor import Processor -def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, - sql_lora_files): +def test_allowed_token_ids_with_lora_vocab( + llama_2_7b_base_huggingface_id, sql_lora_files +): """ Test that we properly resolve the range of allowed token ids for lora adapters that define additional tokens. 
@@ -36,7 +36,8 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + lora_config=vllm_config.lora_config, + ) processor = Processor(vllm_config, tokenizer) lora_request = LoRARequest("1", 1, str(sql_lora_files)) @@ -49,7 +50,8 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, request_id, prompt, params=SamplingParams(allowed_token_ids=lora_token_ids), - lora_request=lora_request) + lora_request=lora_request, + ) # tokens in the base model should not raise an error base_token_ids = [1000, 1001, 1002, 1003] @@ -57,7 +59,8 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, request_id, prompt, params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) + lora_request=lora_request, + ) # tokens not in the lora adapter should raise an error invalid_token_ids = [35000, 35001, 35002, 35003] @@ -66,7 +69,8 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, request_id, prompt, params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) + lora_request=lora_request, + ) # tokens in the lora adapter with no lora request should raise an error with pytest.raises(ValueError): @@ -78,7 +82,8 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, def test_allowed_token_ids_with_lora_adapter_no_vocab( - qwen25vl_base_huggingface_id, qwen25vl_lora_files): + qwen25vl_base_huggingface_id, qwen25vl_lora_files +): """ Test that we properly resolve the range of allowed token ids for lora adapters that do not define additional tokens. 
@@ -102,7 +107,8 @@ def test_allowed_token_ids_with_lora_adapter_no_vocab( tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + lora_config=vllm_config.lora_config, + ) processor = Processor(vllm_config, tokenizer) lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files)) @@ -115,7 +121,8 @@ def test_allowed_token_ids_with_lora_adapter_no_vocab( request_id, prompt, params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) + lora_request=lora_request, + ) # tokens in the base model with no lora request should not raise an error base_token_ids = [1000, 1001, 1002, 1003] @@ -132,4 +139,5 @@ def test_allowed_token_ids_with_lora_adapter_no_vocab( request_id, prompt, params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) + lora_request=lora_request, + ) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index ebc0f26378d2..2219d470e91a 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -8,9 +8,7 @@ from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.utils import WeightsMapper -lora_lst = [ - "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" -] +lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"] BAICHUAN_LORA_MODULES = [ "W_pack", "o_proj", @@ -37,8 +35,9 @@ def test_load_checkpoints( else: expected_lora_modules.append(module) if lora_name == "baichuan7B": - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) # For the baichuan7B model, load it's LoRA, # and the test should pass. LoRAModel.from_local_checkpoint( @@ -48,13 +47,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero": # Test that the target_modules contain prefix # such as "model.layers.0.self_atten.W_pack", and # the test should pass. - peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_zero_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_zero_lora_files, expected_lora_modules, @@ -62,12 +63,14 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero-regex": # Test that the `target_modules` in the form of regular expressions, # such as `model\\..*(W_pack|o_proj)`, and the test should pass. 
- peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_regex_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_regex_lora_files, expected_lora_modules, @@ -75,13 +78,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 - peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + chatglm3_lora_files, max_position_embeddings=4096 + ) with pytest.raises(ValueError, match=expected_error): LoRAModel.from_local_checkpoint( chatglm3_lora_files, @@ -90,11 +95,11 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) def test_lora_weights_mapping(baichuan_lora_files): - packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules @@ -113,8 +118,9 @@ def test_lora_weights_mapping(baichuan_lora_files): ".layers.": ".baichuan_layers.", }, ) - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_local_checkpoint( baichuan_lora_files, expected_lora_modules, diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 50c60341f0d8..bc90a88dc226 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -3,12 +3,14 @@ """ Script to test add_lora, remove_lora, pin_lora, list_loras functions. """ + import pytest from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.lora.request import LoRARequest MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -17,23 +19,24 @@ def make_lora_request(lora_id: int): - return LoRARequest(lora_name=f"{lora_id}", - lora_int_id=lora_id, - lora_path=LORA_MODULE_PATH) + return LoRARequest( + lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=LORA_MODULE_PATH + ) def test_lora_functions_sync(): - max_loras = 4 # Create engine in eager-mode. Due to high max_loras, the CI can # OOM during cuda-graph capture. 
- engine_args = EngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = EngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) llm = LLMEngine.from_engine_args(engine_args) @@ -70,15 +73,16 @@ def run_check(fn, args, expected: list): @pytest.mark.asyncio async def test_lora_functions_async(): - max_loras = 4 - engine_args = AsyncEngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = AsyncEngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) async def run_check(fn, args, expected: list): await fn(args) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index b46d81f1651a..7d20faef541a 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -11,8 +11,12 @@ # Provide absolute path and huggingface lora ids lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] LLAMA_LORA_MODULES = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", ] @@ -40,7 +44,8 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) # Assertions to ensure the model is loaded correctly assert lora_model is not None, "LoRAModel is not loaded correctly" diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8f8a27006cf6..77317d63dc21 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -9,16 +9,21 @@ from torch import nn from vllm.config import LoRAConfig -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - RowParallelLinearWithLoRA) +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, +) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager) +from vllm.lora.models import ( + LoRAMapping, + LoRAModel, + LoRAModelManager, + LRUCacheLoRAModelManager, +) from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, - WorkerLoRAManager) +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager from vllm.platforms import current_platform EMBEDDING_MODULES = { @@ -28,9 +33,11 @@ EMBEDDING_PADDING_MODULES = ["lm_head"] -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) DEFAULT_DTYPE = torch.get_default_dtype() @@ -42,19 +49,20 @@ def use_v0_only(monkeypatch: pytest.MonkeyPatch): LoRAModelManager it 
is okay to just test V0. """ with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') + m.setenv("VLLM_USE_V1", "0") yield @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): - tensors = load_file( - os.path.join(sql_lora_files, "adapter_model.safetensors")) + tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors")) new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors")) + os.path.join(sql_lora_files, "new_embeddings.safetensors") + ) - peft_helper = PEFTHelper.from_local_dir(sql_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + sql_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_lora_tensors( 1, tensors, @@ -62,7 +70,8 @@ def test_from_lora_tensors(sql_lora_files, device): device=device, embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, - embedding_padding_modules=EMBEDDING_PADDING_MODULES) + embedding_padding_modules=EMBEDDING_PADDING_MODULES, + ) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -71,22 +80,27 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.lora_b is not None assert lora.lora_a.device == torch.device(device) assert lora.lora_b.device == torch.device(device) - assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] - ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == lora.lora_b.shape[0], ( + f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + ) assert lora.lora_a.shape[1] == 8 embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None) + (k for k in EMBEDDING_MODULES if k in module_name), None + ) if embeddings_module: assert torch.equal( lora.embeddings_tensor, new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device)) + device=lora.embeddings_tensor.device + ), + ) else: assert lora.embeddings_tensor is None -def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], - device: torch.device) -> LoRAModel: +def create_lora( + lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device +) -> LoRAModel: loras: dict[str, LoRALayerWeights] = {} for name in sub_modules: w = model.get_submodule(name).weight @@ -118,8 +132,7 @@ def create_packed_lora( 8, 16, torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0] // len(replaced_module_names)], - device=device), + torch.rand([8, w.shape[0] // len(replaced_module_names)], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -127,42 +140,42 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model manager = LoRAModelManager( - model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=8, - max_loras=8, - lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0])) + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) model = manager.model - assert isinstance(model.get_submodule("dense1"), - ColumnParallelLinearWithLoRA) - assert isinstance(model.get_submodule("layer1.dense1"), - ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA + ) assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA) - assert 
isinstance(model.get_submodule("layer1.dense2"), - RowParallelLinearWithLoRA) + assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -212,24 +225,21 @@ def test_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -305,27 +315,22 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora4 = create_lora(4, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) 
+ model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) @@ -430,66 +435,83 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): - lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) +def test_lru_cache_worker_adapter_manager( + llama_2_7b_model_extra_embeddings, sql_lora_files, device +): + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, + 2, + llama_2_7b_model_extra_embeddings.unpadded_vocab_size + - lora_config.lora_extra_vocab_size, + lora_config, + device, + EMBEDDING_MODULES, + EMBEDDING_PADDING_MODULES, + ) + worker_adapter_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) - ], mapping) 
+ worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 @@ -498,78 +520,97 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_worker_adapter_manager( + llama_2_7b_model_extra_embeddings, sql_lora_files, device +): # Should remove every LoRA not specified in the request. 
- lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) worker_adapter_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, + 2, + llama_2_7b_model_extra_embeddings.unpadded_vocab_size + - lora_config.lora_extra_vocab_size, + lora_config, + device, + EMBEDDING_MODULES, + EMBEDDING_PADDING_MODULES, + ) + worker_adapter_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files), + ], + 
mapping, + ) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 @@ -577,17 +618,19 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) @@ -598,7 +641,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): model, module_name="gate_up_proj", replaced_module_names=["gate_proj", "up_proj"], - device=device) + device=device, + ) model_lora1 = create_packed_lora( 2, model, @@ -608,19 +652,21 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): empty_replaced_module_name="gate_proj", ) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) model = manager.model - assert isinstance(model.get_submodule("gate_up_proj"), - MergedColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA + ) # Verify packed lora is correct model_lora_clone = model_lora.clone(1) model_lora_clone1 = model_lora1.clone(1) @@ -633,21 +679,27 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) - torch.testing.assert_close(packed_lora.lora_a[0], - model_lora_clone.get_lora("gate_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[0], - model_lora_clone.get_lora("gate_proj").lora_b) - torch.testing.assert_close(packed_lora.lora_a[1], - model_lora_clone.get_lora("up_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[1], - model_lora_clone.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b + ) + torch.testing.assert_close( + packed_lora.lora_a[1], model_lora_clone.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[1], model_lora_clone.get_lora("up_proj").lora_b + ) packed_lora1 = model_lora1.get_lora("gate_up_proj") assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None - torch.testing.assert_close(packed_lora1.lora_a[1], - model_lora_clone1.get_lora("up_proj").lora_a) - 
torch.testing.assert_close(packed_lora1.lora_b[1], - model_lora_clone1.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b + ) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 99fe951bbf07..ce98fe2f8613 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -15,7 +15,8 @@ PROMPT_TEMPLATE = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) IMAGE_ASSETS = [ ImageAsset("stop_sign"), @@ -34,18 +35,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: stop_token_ids=[128001, 128009], # eos_id, eot_id ) - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] + inputs = [ + { + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in IMAGE_ASSETS + ] outputs = llm.generate( inputs, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, ) # Print the outputs. generated_texts: list[str] = [] @@ -58,7 +59,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -68,10 +70,7 @@ def test_minicpmv_lora(minicpmv_lora_files): max_lora_rank=8, enforce_eager=True, max_model_len=2048, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -82,11 +81,13 @@ def test_minicpmv_lora(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output2[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -96,10 +97,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -107,11 +105,13 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with 
ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -122,10 +122,7 @@ def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, - limit_mm_per_prompt={ - "image": 1, - "video": 0 - }, + limit_mm_per_prompt={"image": 1, "video": 0}, fully_sharded_loras=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 0ea07793311c..f80b496d1b34 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -11,15 +11,15 @@ MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: list[str]) -> list[str]: - +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, prompts: list[str] +) -> list[str]: sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -33,8 +33,11 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count( - ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): + if ( + torch.cuda.device_count() < tp_size + and tp_size > 1 + and current_platform.is_cuda_alike() + ): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ @@ -58,7 +61,11 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501 "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501 ] - assert do_sample(llm, mixtral_lora_files, lora_id=1, - prompts=prompts) == expected_lora_output - assert do_sample(llm, mixtral_lora_files, lora_id=2, - prompts=prompts) == expected_lora_output + assert ( + do_sample(llm, mixtral_lora_files, lora_id=1, prompts=prompts) + == expected_lora_output + ) + assert ( + do_sample(llm, mixtral_lora_files, lora_id=2, prompts=prompts) + == expected_lora_output + ) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index f16589e06b2d..d05c1bd22c9f 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -13,34 +13,27 @@ ERROR_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), ( "test_bias", - { - "bias": "all" - }, + {"bias": "all"}, "Adapter bias cannot be used without bias_enabled", ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path): - peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1, - 
max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + long_context_lora_files_16k_1, max_position_embeddings=4096 + ) lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2) peft_helper.validate_legal(lora_config) assert peft_helper.r == 8 @@ -59,8 +52,8 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path): assert peft_helper.context_length == 16384 assert peft_helper.vllm_max_position_embeddings == 4096 assert peft_helper.vllm_long_context_scaling_factor == float( - math.ceil(peft_helper.context_length / - peft_helper.vllm_max_position_embeddings)) + math.ceil(peft_helper.context_length / peft_helper.vllm_max_position_embeddings) + ) # test RSLoRA rslora_config = dict(use_rslora=True) test_dir = tmp_path / "test_rslora" @@ -77,8 +70,7 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path): with open(config_path, "w") as f: json.dump(adapter_config, f) - peft_helper = PEFTHelper.from_local_dir(test_dir, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096) peft_helper.validate_legal(lora_config) scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r) assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3 @@ -109,4 +101,5 @@ def test_peft_helper_error( # Test loading the adapter with pytest.raises(ValueError, match=expected_error): PEFTHelper.from_local_dir( - test_dir, max_position_embeddings=4096).validate_legal(lora_config) + test_dir, max_position_embeddings=4096 + ).validate_legal(lora_config) diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 3090941e6367..ebc027aab384 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -12,30 +12,23 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), + sql_prompt="Which catalog publisher has published the most catalogs?", + context="CREATE TABLE catalogs (catalog_publisher VARCHAR);", + ), PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + sql_prompt="Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + context="CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);", # noqa: E501 ), PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 + sql_prompt="How many marine species are found in the Southern Ocean?", # noqa: E501 + context="CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));", # noqa: E501 ), ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64, stop="### End") outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, ) # Print the outputs. 
generated_texts: list[str] = [] @@ -50,12 +43,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=2, + enforce_eager=True, + enable_chunked_prefill=True, + ) expected_lora_output = [ "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 14fa79ae5b44..e4df9751077d 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -21,11 +21,18 @@ def reset_device(reset_default_device): # Utility shrink and expand operations used as reference implementations. def sgmv_shrink_for_nslices( - nslices: int, inputs_tensor: torch.Tensor, - lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, - prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, - num_tokens: int, scaling: float): + nslices: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: list[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, + batches: int, + max_seq_length: int, + num_tokens: int, + scaling: float, +): """ Wrapper around torch_ops.sgmv_shrink that handles any nslices. """ @@ -44,15 +51,20 @@ def sgmv_shrink_for_nslices( ) -def sgmv_expand_for_nslices(nslices: int, hidden_size: int, - inputs_tensor: torch.Tensor, - lora_weights_lst: list[torch.Tensor], - out_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - prompt_lora_mapping: torch.Tensor, batches: int, - max_seq_length: int, num_tokens: int, - add_inputs: bool) -> None: +def sgmv_expand_for_nslices( + nslices: int, + hidden_size: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: list[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, + batches: int, + max_seq_length: int, + num_tokens: int, + add_inputs: bool, +) -> None: """ Wrapper around torch_ops.sgmv_expand that handles any nslices. """ @@ -94,10 +106,17 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, _dict_lock = Lock() -def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, - dtype: torch.dtype, device: str, seq_length: int, - scaling: float): +def check_lora_shrink_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + scaling: float, +): """ Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink kernels. 
@@ -116,14 +135,19 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() # Setup metadata information for SGMV and reference kernels - sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, - data.prompt_lora_mapping, batches, max_seq_length, - token_nums) + sgmv_meta_args = ( + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + ) # Setup metadata information for the LoRA kernel. - lora_meta = LoRAKernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=token_nums, device="cuda" + ) lora_meta.prepare_tensors(data.token_lora_mapping) ref_out_tensor = data.ref_out_tensor @@ -154,10 +178,17 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, assert_close(out_tensor, ref_out_tensor) -def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, - dtype: torch.dtype, device: str, seq_length: int, - add_inputs: bool): +def check_lora_expand_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + add_inputs: bool, +): """ Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand kernels. @@ -177,14 +208,19 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() # Setup metadata information for SGMV and reference kernels - sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, - data.prompt_lora_mapping, batches, max_seq_length, - token_nums) + sgmv_meta_args = ( + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + ) # Setup metadata information for the LoRA kernel. 
- lora_meta = LoRAKernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=token_nums, device="cuda" + ) lora_meta.prepare_tensors(data.token_lora_mapping) # Setup output tensors @@ -194,21 +230,25 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, with _dict_lock: # lora_expand kernel _LORA_B_PTR_DICT.clear() - triton_ops.lora_expand(data.inputs_tensor, - data.lora_weights, - out_tensor, - *lora_meta.meta_args(token_nums=token_nums), - offset_start=0, - add_inputs=add_inputs) + triton_ops.lora_expand( + data.inputs_tensor, + data.lora_weights, + out_tensor, + *lora_meta.meta_args(token_nums=token_nums), + offset_start=0, + add_inputs=add_inputs, + ) # Reference - sgmv_expand_for_nslices(nslices, - hidden_size, - data.inputs_tensor, - data.lora_weights, - ref_out_tensor, - *sgmv_meta_args, - add_inputs=add_inputs) + sgmv_expand_for_nslices( + nslices, + hidden_size, + data.inputs_tensor, + data.lora_weights, + ref_out_tensor, + *sgmv_meta_args, + add_inputs=add_inputs, + ) assert_close(out_tensor, ref_out_tensor) @@ -299,7 +339,7 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, 128000, 128256, ] -#The size of TP +# The size of TP divisibility = [1, 2, 8, 16, 64] all_hidden_size = [] @@ -331,10 +371,10 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, SEED = [0] -@pytest.mark.parametrize("batches", test_params['batches']) -@pytest.mark.parametrize("num_loras", test_params['num_loras']) -@pytest.mark.parametrize("rank", test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("batches", test_params["batches"]) +@pytest.mark.parametrize("num_loras", test_params["num_loras"]) +@pytest.mark.parametrize("rank", test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"]) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -358,31 +398,35 @@ def test_kernels( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) - - -@pytest.mark.parametrize("batches", hs_test_params['batches']) -@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) -@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) + + +@pytest.mark.parametrize("batches", hs_test_params["batches"]) +@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"]) +@pytest.mark.parametrize("rank", hs_test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"]) 
@pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -406,22 +450,26 @@ def test_kernels_hidden_size( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index caa31fdb0e73..6b1180ea68d0 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -20,28 +20,27 @@ class ModelWithQuantization: MODELS: list[ModelWithQuantization] -#AWQ quantization is currently not supported in ROCm. +# AWQ quantization is currently not supported in ROCm. if current_platform.is_rocm(): MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] else: MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", - quantization="awq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", quantization="awq" + ), ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - max_tokens: int = 256) -> list[str]: +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, max_tokens: int = 256 +) -> list[str]: raw_prompts = [ "Give me an orange-ish brown color", "Give me a neon pink color", @@ -52,14 +51,14 @@ def format_prompt_tuples(prompt): prompts = [format_prompt_tuples(p) for p in raw_prompts] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=max_tokens, - stop=["<|im_end|>"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=max_tokens, stop=["<|im_end|>"] + ) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -72,22 +71,22 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("model", MODELS) def test_quant_model_lora(tinyllama_lora_files, model): - llm = vllm.LLM( model=model.model_path, enable_lora=True, max_num_seqs=16, max_loras=4, max_model_len=400, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) if model.quantization is None: expected_no_lora_output = [ "Here are some examples of orange-brown colors", - "I'm sorry, I don't have" + "I'm sorry, I don't have", ] expected_lora_output = [ "#ff8050", @@ -115,43 +114,31 @@ def test_quant_model_lora(tinyllama_lora_files, model): def expect_match(output, expected_output): # HACK: GPTQ lora outputs are just incredibly unstable. # Assert that the outputs changed. - if (model.quantization == "gptq" - and expected_output is expected_lora_output): + if model.quantization == "gptq" and expected_output is expected_lora_output: assert output != expected_no_lora_output for i, o in enumerate(output): - assert o.startswith( - '#'), f"Expected example {i} to start with # but got {o}" + assert o.startswith("#"), ( + f"Expected example {i} to start with # but got {o}" + ) return assert output == expected_output max_tokens = 10 print("lora adapter created") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=0, max_tokens=max_tokens) expect_match(output, expected_no_lora_output) print("lora 1") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=1, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=1, max_tokens=max_tokens) expect_match(output, expected_lora_output) print("no lora") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=0, max_tokens=max_tokens) expect_match(output, expected_no_lora_output) print("lora 2") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=2, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=2, max_tokens=max_tokens) expect_match(output, expected_lora_output) print("removing lora") @@ -161,8 +148,7 @@ def expect_match(output, expected_output): @pytest.mark.parametrize("model", MODELS) -def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, - model): +def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, model): if num_gpus_available < 2: pytest.skip(f"Not enough GPUs for tensor parallelism {2}") if model.quantization == "gptq": @@ -172,10 +158,11 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, enable_lora=True, max_num_seqs=16, max_loras=4, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 @@ -187,9 +174,10 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, max_num_seqs=16, max_loras=4, tensor_parallel_size=2, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp2 = do_sample(llm_tp2, 
tinyllama_lora_files, lora_id=1) del llm_tp2 diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 604bb307b889..9d3e2b265e3a 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -39,7 +39,8 @@ class Qwen2VLTester: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") + "<|im_start|>assistant\n" + ) def __init__(self, config: TestConfig): self.config = config @@ -58,68 +59,68 @@ def _initialize_llm(self) -> vllm.LLM: max_model_len=self.config.max_model_len, ) - def run_test(self, - images: list[ImageAsset], - expected_outputs: list[str], - lora_id: Optional[int] = None, - temperature: float = 0, - max_tokens: int = 5): - + def run_test( + self, + images: list[ImageAsset], + expected_outputs: list[str], + lora_id: Optional[int] = None, + temperature: float = 0, + max_tokens: int = 5, + ): sampling_params = vllm.SamplingParams( temperature=temperature, max_tokens=max_tokens, ) - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.generate(inputs, - sampling_params, - lora_request=lora_request) - generated_texts = [ - output.outputs[0].text.strip() for output in outputs + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images ] + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request) + generated_texts = [output.outputs[0].text.strip() for output in outputs] + # Validate outputs for generated, expected in zip(generated_texts, expected_outputs): - assert expected.startswith( - generated), f"Generated text {generated} doesn't " + assert expected.startswith(generated), ( + f"Generated text {generated} doesn't " + ) f"match expected pattern {expected}" - def run_beam_search_test(self, - images: list[ImageAsset], - expected_outputs: list[list[str]], - lora_id: Optional[int] = None, - temperature: float = 0, - beam_width: int = 2, - max_tokens: int = 5): - - beam_search_params = BeamSearchParams(beam_width=beam_width, - max_tokens=max_tokens, - temperature=temperature) - - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.beam_search(inputs, - beam_search_params, - lora_request=lora_request) + def run_beam_search_test( + self, + images: list[ImageAsset], + expected_outputs: list[list[str]], + lora_id: Optional[int] = None, + temperature: float = 0, + beam_width: int = 2, + max_tokens: int = 5, + ): + beam_search_params = BeamSearchParams( + beam_width=beam_width, max_tokens=max_tokens, temperature=temperature + ) + + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images + ] + + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.beam_search( + inputs, beam_search_params, lora_request=lora_request + ) for output_obj, expected_outs in zip(outputs, expected_outputs): output_texts = [seq.text for seq in output_obj.sequences] - assert output_texts == expected_outs, \ - f"Generated texts {output_texts} 
do not match expected {expected_outs}" # noqa: E501 + assert output_texts == expected_outs, ( + f"Generated texts {output_texts} do not match expected {expected_outs}" + ) # noqa: E501 TEST_IMAGES = [ @@ -146,27 +147,25 @@ def run_beam_search_test(self, @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA""" - config = TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA through beam search.""" - config = TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs @@ -178,7 +177,8 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): tester.run_beam_search_test( [ImageAsset("cherry_blossom")], expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS, - lora_id=lora_id) + lora_id=lora_id, + ) @pytest.mark.xfail( @@ -191,12 +191,9 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): ) def test_qwen25vl_lora(qwen25vl_lora_files): """Test Qwen 2.5 VL model with LoRA""" - config = TestConfig(model_path=QWEN25VL_MODEL_PATH, - lora_path=qwen25vl_lora_files) + config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py index 6c93e577611f..c70e58a375c7 100644 --- a/tests/lora/test_resolver.py +++ b/tests/lora/test_resolver.py @@ -12,13 +12,15 @@ class DummyLoRAResolver(LoRAResolver): """A dummy LoRA resolver for testing.""" - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test_lora": return LoRARequest( lora_name=lora_name, lora_path=f"/dummy/path/{base_model_name}/{lora_name}", - lora_int_id=abs(hash(lora_name))) + lora_int_id=abs(hash(lora_name)), + ) return None @@ -70,6 +72,5 @@ async def test_dummy_resolver_resolve(): assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}" # Test failed resolution - result = await dummy_resolver.resolve_lora(base_model_name, - "nonexistent_lora") + result = await dummy_resolver.resolve_lora(base_model_name, "nonexistent_lora") assert result is None diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index 6cfdaf50d33c..740da5e35529 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -22,22 +22,25 
@@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): ) lora_request = LoRARequest("1", 1, sql_lora_files) assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( - prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer_group.encode_async( - prompt="prompt", lora_request=lora_request) - assert isinstance(tokenizer_group.get_lora_tokenizer(None), - PreTrainedTokenizerBase) + prompt="prompt", lora_request=lora_request + ) + assert reference_tokenizer.encode("prompt") == await tokenizer_group.encode_async( + prompt="prompt", lora_request=lora_request + ) + assert isinstance(tokenizer_group.get_lora_tokenizer(None), PreTrainedTokenizerBase) assert tokenizer_group.get_lora_tokenizer( - None) == await tokenizer_group.get_lora_tokenizer_async(None) + None + ) == await tokenizer_group.get_lora_tokenizer_async(None) - assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) + assert isinstance( + tokenizer_group.get_lora_tokenizer(lora_request), PreTrainedTokenizerBase + ) assert tokenizer_group.get_lora_tokenizer( - lora_request) != tokenizer_group.get_lora_tokenizer(None) + lora_request + ) != tokenizer_group.get_lora_tokenizer(None) assert tokenizer_group.get_lora_tokenizer( - lora_request) == await tokenizer_group.get_lora_tokenizer_async( - lora_request) + lora_request + ) == await tokenizer_group.get_lora_tokenizer_async(lora_request) def test_get_lora_tokenizer(sql_lora_files, tmp_path): @@ -66,7 +69,6 @@ def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): max_input_length=None, ) if enable_lora: - assert tokenizer_group.lora_tokenizers.capacity == max( - max_num_seqs, max_loras) + assert tokenizer_group.lora_tokenizers.capacity == max(max_num_seqs, max_loras) else: assert tokenizer_group.lora_tokenizers.capacity == 0 diff --git a/tests/lora/test_transformers_model.py b/tests/lora/test_transformers_model.py index 5065a2fb7164..da924485c3e2 100644 --- a/tests/lora/test_transformers_model.py +++ b/tests/lora/test_transformers_model.py @@ -24,20 +24,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query="What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 ), PROMPT_TEMPLATE.format( - query= - "What are all distinct countries where singers above age 20 are from?" # noqa: E501 + query="What are all distinct countries where singers above age 20 are from?" # noqa: E501 ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -49,13 +47,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_ilama_lora(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -65,20 +65,23 @@ def test_ilama_lora(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -88,20 +91,23 @@ def test_ilama_lora_tp4(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index b343bef0a920..aed91d98ddbd 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -9,8 +9,11 @@ from huggingface_hub.utils import HfHubHTTPError from torch import nn -from vllm.lora.utils import (get_adapter_absolute_path, - parse_fine_tuned_lora_name, replace_submodule) +from vllm.lora.utils import ( + get_adapter_absolute_path, + parse_fine_tuned_lora_name, + replace_submodule, +) from vllm.model_executor.models.utils import WeightsMapper @@ -24,10 +27,12 @@ class LoRANameParserTestConfig(NamedTuple): def test_parse_fine_tuned_lora_name_valid(): fixture = [ - LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", - "lm_head", True, False), - LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", - "lm_head", False, False), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_A.weight", "lm_head", 
True, False + ), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_B.weight", "lm_head", False, False + ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", @@ -71,7 +76,8 @@ def test_parse_fine_tuned_lora_name_valid(): True, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", @@ -79,7 +85,8 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_A.weight", @@ -87,7 +94,8 @@ def test_parse_fine_tuned_lora_name_valid(): True, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_B.weight", @@ -95,12 +103,14 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), ] for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: - assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) + assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name( + name, weights_mapper + ) def test_parse_fine_tuned_lora_name_invalid(): @@ -115,22 +125,28 @@ def test_parse_fine_tuned_lora_name_invalid(): def test_replace_submodule(): model = nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(764, 100)), - ("act1", nn.ReLU()), - ("dense2", nn.Linear(100, 50)), - ( - "seq1", - nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(100, 10)), - ("dense2", nn.Linear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("output", nn.Linear(50, 10)), - ("outact", nn.Sigmoid()), - ])) + OrderedDict( + [ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict( + [ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ] + ) + ) sigmoid = nn.Sigmoid() @@ -143,52 +159,51 @@ def test_replace_submodule(): # Unit tests for get_adapter_absolute_path -@patch('os.path.isabs') +@patch("os.path.isabs") def test_get_adapter_absolute_path_absolute(mock_isabs): - path = '/absolute/path/to/lora' + path = "/absolute/path/to/lora" mock_isabs.return_value = True assert get_adapter_absolute_path(path) == path -@patch('os.path.expanduser') +@patch("os.path.expanduser") def test_get_adapter_absolute_path_expanduser(mock_expanduser): # Path with ~ that needs to be expanded - path = '~/relative/path/to/lora' - absolute_path = '/home/user/relative/path/to/lora' + path = "~/relative/path/to/lora" + absolute_path = "/home/user/relative/path/to/lora" mock_expanduser.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('os.path.exists') -@patch('os.path.abspath') +@patch("os.path.exists") +@patch("os.path.abspath") def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist): # Relative path that exists locally - path = 'relative/path/to/lora' - 
absolute_path = '/absolute/path/to/lora' + path = "relative/path/to/lora" + absolute_path = "/absolute/path/to/lora" mock_exist.return_value = True mock_abspath.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download): # Hugging Face model identifier - path = 'org/repo' - absolute_path = '/mock/snapshot/path' + path = "org/repo" + absolute_path = "/mock/snapshot/path" mock_exist.return_value = False mock_snapshot_download.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface_error(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface_error( + mock_exist, mock_snapshot_download +): # Hugging Face model identifier with download error - path = 'org/repo' + path = "org/repo" mock_exist.return_value = False - mock_snapshot_download.side_effect = HfHubHTTPError( - "failed to query model info") + mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") assert get_adapter_absolute_path(path) == path diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 9999c1be54ea..172e8a440d3f 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -8,9 +8,16 @@ from unittest.mock import patch import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker as V1Worker @@ -21,9 +28,9 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - - def set_active_loras(worker: Union[Worker, V1Worker], - lora_requests: list[LoRARequest]): + def set_active_loras( + worker: Union[Worker, V1Worker], lora_requests: list[LoRARequest] + ): lora_mapping = LoRAMapping([], []) if isinstance(worker, Worker): # v0 case @@ -31,7 +38,8 @@ def set_active_loras(worker: Union[Worker, V1Worker], else: # v1 case worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) + lora_requests, lora_mapping + ) worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker @@ -63,9 +71,9 @@ def set_active_loras(worker: Union[Worker, V1Worker], swap_space=0, cache_dtype="auto", ), - lora_config=LoRAConfig(max_lora_rank=8, - max_cpu_loras=NUM_LORAS, - max_loras=NUM_LORAS), + lora_config=LoRAConfig( + max_lora_rank=8, max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS + ), ) worker = worker_cls( vllm_config=vllm_config, @@ -81,23 +89,22 @@ def set_active_loras(worker: Union[Worker, V1Worker], assert worker.list_loras() == set() lora_requests = [ - LoRARequest(str(i + 1), i + 1, sql_lora_files) - for i in range(NUM_LORAS) + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS) ] set_active_loras(worker, lora_requests) assert worker.list_loras() == { - 
lora_request.lora_int_id - for lora_request in lora_requests + lora_request.lora_int_id for lora_request in lora_requests } for i in range(NUM_LORAS): random.seed(i) - iter_lora_requests = random.choices(lora_requests, - k=random.randint(1, NUM_LORAS)) + iter_lora_requests = random.choices( + lora_requests, k=random.randint(1, NUM_LORAS) + ) random.shuffle(iter_lora_requests) - iter_lora_requests = iter_lora_requests[:-random.randint(0, NUM_LORAS)] + iter_lora_requests = iter_lora_requests[: -random.randint(0, NUM_LORAS)] set_active_loras(worker, lora_requests) assert worker.list_loras().issuperset( - {lora_request.lora_int_id - for lora_request in iter_lora_requests}) + {lora_request.lora_int_id for lora_request in iter_lora_requests} + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cc1b0d81955b..841d663a5e3c 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -10,7 +10,6 @@ class DummyLoRAManager: - def __init__(self, device: torch.device = "cuda:0"): super().__init__() self._loras: dict[str, LoRALayerWeights] = {} @@ -33,12 +32,12 @@ def init_random_lora( module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([weight.shape[1], rank], - dtype=weight.dtype, - device=self._device), - lora_b=torch.rand([rank, weight.shape[0]], - dtype=weight.dtype, - device=self._device), + lora_a=torch.rand( + [weight.shape[1], rank], dtype=weight.dtype, device=self._device + ), + lora_b=torch.rand( + [rank, weight.shape[0]], dtype=weight.dtype, device=self._device + ), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand( @@ -143,27 +142,26 @@ def generate_data( op_type, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, ).to(device) total_tokens = seq_len_tensor.sum() if op_type == "shrink": - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) lora_weights = torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, ).to(device) # shrink op need atomic_add, so output is initinized by 0 - ref_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=dtype, - device=inputs_tensor.device) + ref_out_tensor = torch.zeros( + (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device + ) # NOTE shrink kernel using torch.float32 as output type - our_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=torch.float32).to(device) + our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32).to( + device + ) else: inputs_tensor = torch.rand( (total_tokens, max_rank), @@ -181,15 +179,16 @@ def generate_data( ).to(device) # Ensure the same input. 
our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )).to(device) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ).to(device) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]].copy_(lora_index) + indices[current_offset : current_offset + seq_len_tensor[b_id]].copy_( + lora_index + ) current_offset += seq_len_tensor[b_id].item() return PunicaTensors( @@ -214,8 +213,7 @@ def generate_data_for_expand_nslices( nslices, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -231,22 +229,25 @@ def generate_data_for_expand_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype).to( + device + ) # Ensure the same input. our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) @@ -273,8 +274,7 @@ def generate_data_for_nslices( op_type, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -283,9 +283,7 @@ def generate_data_for_nslices( lora_weights_lst = [] if op_type == "shrink": - - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) for _ in range(nslices): if op_type == "shrink": @@ -293,7 +291,8 @@ def generate_data_for_nslices( torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # NOTE shrink kernel using torch.float32 as output type # shrink op need atomic_add, so output is initinized by 0 our_out_tensor = torch.zeros( @@ -310,23 +309,26 @@ def generate_data_for_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - our_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + our_out_tensor = torch.rand( + (total_tokens, hidden_size * nslices), 
dtype=dtype + ).to(device) # Ensure the same input. ref_out_tensor = our_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 7bb5d8980d61..e75fac0481be 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -22,7 +22,7 @@ def use_v0_only(monkeypatch): """ This module tests V0 internals, so set VLLM_USE_V1=0. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") MODELS = [ @@ -40,29 +40,28 @@ def test_metric_counter_prompt_tokens( dtype: str, max_tokens: int, ) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: + with vllm_runner( + model, dtype=dtype, disable_log_stats=False, gpu_memory_utilization=0.4 + ) as vllm_model: tokenizer = vllm_model.model.get_tokenizer() - prompt_token_counts = [ - len(tokenizer.encode(p)) for p in example_prompts - ] + prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts] # This test needs at least 2 prompts in a batch of different lengths to # verify their token count is correct despite padding. assert len(example_prompts) > 1, "at least 2 prompts are required" assert prompt_token_counts[0] != prompt_token_counts[1], ( - "prompts of different lengths are required") + "prompts of different lengths are required" + ) vllm_prompt_token_count = sum(prompt_token_counts) _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.model.llm_engine.stat_loggers["prometheus"] metric_count = stat_logger.metrics.counter_prompt_tokens.labels( - **stat_logger.labels)._value.get() + **stat_logger.labels + )._value.get() assert vllm_prompt_token_count == metric_count, ( - f"prompt token count: {vllm_prompt_token_count!r}\n" - f"metric: {metric_count!r}") + f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}" + ) @pytest.mark.parametrize("model", MODELS) @@ -75,15 +74,15 @@ def test_metric_counter_generation_tokens( dtype: str, max_tokens: int, ) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: + with vllm_runner( + model, dtype=dtype, disable_log_stats=False, gpu_memory_utilization=0.4 + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.model.llm_engine.stat_loggers["prometheus"] metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() + **stat_logger.labels + )._value.get() vllm_generation_count = 0 for i in range(len(example_prompts)): vllm_output_ids, vllm_output_str = vllm_outputs[i] @@ -93,8 +92,8 @@ def test_metric_counter_generation_tokens( 
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) assert vllm_generation_count == metric_count, ( - f"generation token count: {vllm_generation_count!r}\n" - f"metric: {metric_count!r}") + f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}" + ) @pytest.mark.parametrize("model", MODELS) @@ -109,17 +108,18 @@ def test_metric_counter_generation_tokens_multi_step( ) -> None: num_scheduler_steps = 8 with vllm_runner( - model, - disable_log_stats=False, - gpu_memory_utilization=0.4, - num_scheduler_steps=num_scheduler_steps, - disable_async_output_proc=disable_async_output_proc, + model, + disable_log_stats=False, + gpu_memory_utilization=0.4, + num_scheduler_steps=num_scheduler_steps, + disable_async_output_proc=disable_async_output_proc, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.model.llm_engine.stat_loggers["prometheus"] metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() + **stat_logger.labels + )._value.get() vllm_generation_count = 0 for i in range(len(example_prompts)): vllm_output_ids, vllm_output_str = vllm_outputs[i] @@ -130,25 +130,29 @@ def test_metric_counter_generation_tokens_multi_step( # The multi-step scheduling will continue to execute forward even when # encountering EOS, leading to slightly imprecise metrics. - assert abs(vllm_generation_count - metric_count) <\ - len(example_prompts) * num_scheduler_steps, \ - (f"generation token count: {vllm_generation_count!r}\n" - f"metric: {metric_count!r}") + assert ( + abs(vllm_generation_count - metric_count) + < len(example_prompts) * num_scheduler_steps + ), f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}" @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize( "served_model_name", - [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) -def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, - served_model_name: list[str]) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.3, - served_model_name=served_model_name) as vllm_model: - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]], +) +def test_metric_set_tag_model_name( + vllm_runner, model: str, dtype: str, served_model_name: list[str] +) -> None: + with vllm_runner( + model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.3, + served_model_name=served_model_name, + ) as vllm_model: + stat_logger = vllm_model.model.llm_engine.stat_loggers["prometheus"] metrics_tag_content = stat_logger.labels["model_name"] if envs.VLLM_CI_USE_S3: @@ -156,12 +160,14 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, if served_model_name is None or served_model_name == []: assert metrics_tag_content == model, ( f"Metrics tag model_name is wrong! expect: {model!r}\n" - f"actual: {metrics_tag_content!r}") + f"actual: {metrics_tag_content!r}" + ) else: assert metrics_tag_content == served_model_name[0], ( f"Metrics tag model_name is wrong! 
expect: " f"{served_model_name[0]!r}\n" - f"actual: {metrics_tag_content!r}") + f"actual: {metrics_tag_content!r}" + ) @pytest.mark.parametrize("model", MODELS) @@ -197,8 +203,7 @@ async def test_async_engine_log_metrics_regression( async for _ in results: pass - assert_metrics(model, async_engine.engine, disable_log_stats, - len(example_prompts)) + assert_metrics(model, async_engine.engine, disable_log_stats, len(example_prompts)) @pytest.mark.parametrize("model", MODELS) @@ -245,18 +250,17 @@ def test_metric_spec_decode( k = 5 with vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_config={ - "model": model, - "num_speculative_tokens": k, - }, + model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_config={ + "model": model, + "num_speculative_tokens": k, + }, ) as vllm_model: - # Force log interval to be 0 to catch all metrics. - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.model.llm_engine.stat_loggers["prometheus"] stat_logger.local_interval = 0 # Note that the purpose of this test is to verify spec decode @@ -267,8 +271,7 @@ def test_metric_spec_decode( "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, "counter_spec_decode_num_draft_tokens": lambda v: v == k, - "counter_spec_decode_num_emitted_tokens": - lambda v: 0 <= v <= k + 1, + "counter_spec_decode_num_emitted_tokens": lambda v: 0 <= v <= k + 1, } # Use one request to better inspect the metrics. @@ -276,12 +279,15 @@ def test_metric_spec_decode( _ = vllm_model.generate_greedy(prompts, max_tokens) for metric_name, is_expected in metric_name_to_expected_fn.items(): - metric_val = getattr( - stat_logger.metrics, - metric_name).labels(**stat_logger.labels)._value.get() + metric_val = ( + getattr(stat_logger.metrics, metric_name) + .labels(**stat_logger.labels) + ._value.get() + ) assert is_expected(metric_val), ( f"the value of metric {metric_name} ({metric_val}) " - "does not meet expectation") + "does not meet expectation" + ) @pytest.mark.parametrize("model", MODELS) @@ -313,7 +319,6 @@ def test_metric_spec_decode_interval( engine = LLMEngine.from_engine_args(engine_args) try: - engine.add_request( "request-id-0", example_prompts[0], @@ -321,7 +326,7 @@ def test_metric_spec_decode_interval( ) # set log internal - stat_logger = engine.stat_loggers['prometheus'] + stat_logger = engine.stat_loggers["prometheus"] stat_logger.local_interval = log_interval # prefill @@ -358,35 +363,37 @@ def test_metric_spec_decode_interval( "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, "counter_spec_decode_num_draft_tokens": lambda v: v == k, - "counter_spec_decode_num_emitted_tokens": - lambda v: 0 <= v <= k + 1, + "counter_spec_decode_num_emitted_tokens": lambda v: 0 <= v <= k + 1, } for metric_name, is_expected in metric_name_to_expected_fn.items(): - metric_val = getattr( - stat_logger.metrics, - metric_name).labels(**stat_logger.labels)._value.get() + metric_val = ( + getattr(stat_logger.metrics, metric_name) + .labels(**stat_logger.labels) + ._value.get() + ) assert is_expected(metric_val), ( f"the value of metric {metric_name} ({metric_val}) " - "does not meet expectation") + "does not meet expectation" + ) finally: del engine cleanup_dist_env_and_memory() -def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, - num_requests: int) -> 
None: +def assert_metrics( + model: str, engine: LLMEngine, disable_log_stats: bool, num_requests: int +) -> None: if disable_log_stats: with pytest.raises(AttributeError): _ = engine.stat_loggers else: - assert (engine.stat_loggers - is not None), "engine.stat_loggers should be set" + assert engine.stat_loggers is not None, "engine.stat_loggers should be set" # Ensure the count bucket of request-level histogram metrics matches # the number of requests as a simple sanity check to ensure metrics are # generated - labels = {'model_name': model} + labels = {"model_name": model} request_histogram_metrics = [ "vllm:e2e_request_latency_seconds", "vllm:request_prompt_tokens", @@ -395,10 +402,8 @@ def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, "vllm:request_params_max_tokens", ] for metric_name in request_histogram_metrics: - metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", - labels) - assert ( - metric_value == num_requests), "Metrics should be collected" + metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", labels) + assert metric_value == num_requests, "Metrics should be collected" @pytest.mark.parametrize("model", MODELS) @@ -418,9 +423,7 @@ def test_engine_log_metrics_ray( # We have to run in a Ray task for Ray metrics to be emitted correctly @ray.remote(num_gpus=1) def _inner(): - class _RayPrometheusStatLogger(RayPrometheusStatLogger): - def __init__(self, *args, **kwargs): self._i = 0 super().__init__(*args, **kwargs) @@ -438,7 +441,8 @@ def log(self, *args, **kwargs): logger = _RayPrometheusStatLogger( local_interval=0.5, labels=dict(model_name=engine.model_config.served_model_name), - vllm_config=engine.vllm_config) + vllm_config=engine.vllm_config, + ) engine.add_logger("ray", logger) for i, prompt in enumerate(example_prompts): engine.add_request( diff --git a/tests/mistral_tool_use/conftest.py b/tests/mistral_tool_use/conftest.py index e89e60c5a02e..d476e709a8c5 100644 --- a/tests/mistral_tool_use/conftest.py +++ b/tests/mistral_tool_use/conftest.py @@ -17,8 +17,9 @@ def server_config(request): config = CONFIGS[request.param] if current_platform.is_rocm() and not config.get("supports_rocm", True): - pytest.skip("The {} model can't be tested on the ROCm platform".format( - config["model"])) + pytest.skip( + "The {} model can't be tested on the ROCm platform".format(config["model"]) + ) # download model and tokenizer using transformers snapshot_download(config["model"]) @@ -30,8 +31,9 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=480) as server: + with RemoteOpenAIServer( + model, ARGS + args_for_model, max_wait_seconds=480 + ) as server: yield server diff --git a/tests/mistral_tool_use/test_mistral_tool_calls.py b/tests/mistral_tool_use/test_mistral_tool_calls.py index 9bf6863f3f2b..3c4a543abe41 100644 --- a/tests/mistral_tool_use/test_mistral_tool_calls.py +++ b/tests/mistral_tool_use/test_mistral_tool_calls.py @@ -19,12 +19,12 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL], tool_choice=WEATHER_TOOL, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 1 + assert 
choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral diff --git a/tests/mistral_tool_use/utils.py b/tests/mistral_tool_use/utils.py index 7a026cd9bb61..13a234f8e26b 100644 --- a/tests/mistral_tool_use/utils.py +++ b/tests/mistral_tool_use/utils.py @@ -18,17 +18,16 @@ class ServerConfig(TypedDict, total=False): CONFIGS: dict[str, ServerConfig] = { "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--tokenizer-mode", "mistral", - "--ignore-patterns=\"consolidated.safetensors\"" + "--tokenizer-mode", + "mistral", + '--ignore-patterns="consolidated.safetensors"', ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, } diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py index c6d89d849e9f..7539c8990db6 100644 --- a/tests/model_executor/conftest.py +++ b/tests/model_executor/conftest.py @@ -6,8 +6,10 @@ @pytest.fixture def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) @pytest.fixture @@ -15,38 +17,25 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + "minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 140f00294765..0f2cb71fcd24 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -6,18 +6,31 @@ from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.activation import (GeluAndMul, - ReLUSquaredActivation, - SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, - vllm_topk_softmax) +from vllm.model_executor.layers.activation import ( + GeluAndMul, + ReLUSquaredActivation, + SiluAndMul, +) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + dispatch_topk_func, + vllm_topk_softmax, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, +) from 
vllm.model_executor.layers.layernorm import ( - RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, - rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) + RMSNorm, + dispatch_cuda_rmsnorm_func, + fused_add_rms_norm, + rms_norm, + rocm_aiter_fused_add_rms_norm, + rocm_aiter_rms_norm, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) + cutlass_scaled_mm, + dispatch_w8a8_blockscale_func, + w8a8_block_fp8_matmul, +) from vllm.platforms import current_platform @@ -64,13 +77,22 @@ class Relu3(ReLUSquaredActivation): ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False), # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), - ]) -def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, - ops_enabled: list[int], default_on: bool): + ], +) +def test_enabled_ops( + env: str, + torch_level: int, + use_inductor: bool, + ops_enabled: list[int], + default_on: bool, +): vllm_config = VllmConfig( - compilation_config=CompilationConfig(use_inductor=bool(use_inductor), - level=torch_level, - custom_ops=env.split(","))) + compilation_config=CompilationConfig( + use_inductor=bool(use_inductor), + level=torch_level, + custom_ops=env.split(","), + ) + ) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on @@ -98,39 +120,49 @@ class SiluAndMul2(SiluAndMul): @pytest.mark.parametrize( - "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) + "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"] +) def test_enabled_ops_invalid(env: str): with pytest.raises(Exception): # noqa - vllm_config = VllmConfig(compilation_config=CompilationConfig( - custom_ops=env.split(","))) + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=env.split(",")) + ) with set_current_vllm_config(vllm_config): RMSNorm(1024).enabled() @pytest.mark.skipif( not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") + reason="AITER is a feature exclusive for ROCm and FP8_FNUZ", +) @pytest.mark.parametrize("use_cutlass", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - +def test_w8a8_blockscale_dispatch( + use_cutlass: bool, + use_rocm_aiter: str, + use_rocm_aiter_gemm_w8a8_blockscale: str, + monkeypatch, +): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) + monkeypatch.setenv( + "VLLM_ROCM_USE_AITER_LINEAR", use_rocm_aiter_gemm_w8a8_blockscale + ) - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) + use_aiter_and_is_supported = bool(int(use_rocm_aiter)) and bool( + int(use_rocm_aiter_gemm_w8a8_blockscale) + ) block_scale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) + use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported + ) if use_cutlass: assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) + elif ( + 
current_platform.is_rocm() + and int(use_rocm_aiter) + and int(use_rocm_aiter_gemm_w8a8_blockscale) + ): + assert block_scale_func == (torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) else: assert block_scale_func == w8a8_block_fp8_matmul @@ -142,7 +174,9 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): is_rocm_aiter_moe_enabled.cache_clear() if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax) + rocm_aiter_topk_softmax, + ) + assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @@ -151,22 +185,28 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="AITER is a feature exclusive for ROCm") -def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str, - use_rocm_aiter_norm: str, monkeypatch): +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" +) +def test_rms_norm_dispatch( + add_residual: bool, use_rocm_aiter: str, use_rocm_aiter_norm: str, monkeypatch +): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual) if not add_residual: - if current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): + if ( + current_platform.is_rocm() + and int(use_rocm_aiter) + and int(use_rocm_aiter_norm) + ): assert rms_norm_func == rocm_aiter_rms_norm else: assert rms_norm_func == rms_norm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): + elif ( + current_platform.is_rocm() and int(use_rocm_aiter) and int(use_rocm_aiter_norm) + ): assert rms_norm_func == rocm_aiter_fused_add_rms_norm else: assert rms_norm_func == fused_add_rms_norm diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index f08c7f7efccb..d5b8de9b490c 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -11,15 +11,16 @@ from vllm.config import ModelConfig from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor, - get_local_guided_decoding_logits_processor) + get_local_guided_decoding_logits_processor, +) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( - JSONLogitsProcessor, RegexLogitsProcessor) + JSONLogitsProcessor, + RegexLogitsProcessor, +) from vllm.sampling_params import GuidedDecodingParams -MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' -GUIDED_DECODING_BACKENDS = [ - "outlines", "lm-format-enforcer", "xgrammar", "guidance" -] +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar", "guidance"] GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"] REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" @@ -35,16 +36,12 @@ def deepseek_r1_qwen_tokenizer(): return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) -def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, - sample_json_schema): +def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor 
and JSONLogitsProcessor.""" - regex_LP = RegexLogitsProcessor(sample_regex, - zephyr_7B_tokenzer, - reasoner=None) - json_LP = JSONLogitsProcessor(sample_json_schema, - zephyr_7B_tokenzer, - whitespace_pattern=None, - reasoner=None) + regex_LP = RegexLogitsProcessor(sample_regex, zephyr_7B_tokenzer, reasoner=None) + json_LP = JSONLogitsProcessor( + sample_json_schema, zephyr_7B_tokenzer, whitespace_pattern=None, reasoner=None + ) tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -62,11 +59,9 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, @pytest.mark.asyncio @pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("is_local", [True, False]) -async def test_guided_logits_processor_black_box(backend: str, is_local: bool, - sample_regex, - sample_json_schema, - zephyr_7B_tokenzer): - +async def test_guided_logits_processor_black_box( + backend: str, is_local: bool, sample_regex, sample_json_schema, zephyr_7B_tokenzer +): config = ModelConfig( MODEL_NAME, task="generate", @@ -78,10 +73,15 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, ) regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) - regex_lp = get_local_guided_decoding_logits_processor( - regex_request, zephyr_7B_tokenzer, config) if is_local else \ - await get_guided_decoding_logits_processor( - regex_request, zephyr_7B_tokenzer, config) + regex_lp = ( + get_local_guided_decoding_logits_processor( + regex_request, zephyr_7B_tokenzer, config + ) + if is_local + else await get_guided_decoding_logits_processor( + regex_request, zephyr_7B_tokenzer, config + ) + ) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -90,10 +90,10 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - json_request = GuidedDecodingParams(json=sample_json_schema, - backend=backend) + json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, zephyr_7B_tokenzer, config) + json_request, zephyr_7B_tokenzer, config + ) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -103,14 +103,17 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, @pytest.mark.asyncio -@pytest.mark.parametrize("backend", - GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT) +@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT) @pytest.mark.parametrize("is_local", [True, False]) @pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"]) async def test_guided_logits_processor_with_reasoning( - backend: str, is_local: bool, reasoning_backend: str, sample_regex, - sample_json_schema, deepseek_r1_qwen_tokenizer): - + backend: str, + is_local: bool, + reasoning_backend: str, + sample_regex, + sample_json_schema, + deepseek_r1_qwen_tokenizer, +): config = ModelConfig( REASONING_MODEL_NAME, task="generate", @@ -120,16 +123,18 @@ async def test_guided_logits_processor_with_reasoning( seed=0, dtype="bfloat16", ) - token_ids = deepseek_r1_qwen_tokenizer.encode( - "<think>here is the thinking process") + token_ids = deepseek_r1_qwen_tokenizer.encode("<think>here is the thinking process") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) - regex_lp = get_local_guided_decoding_logits_processor(regex_request,
deepseek_r1_qwen_tokenizer, config, - reasoning_backend) if is_local else \ - await get_guided_decoding_logits_processor( - regex_request, deepseek_r1_qwen_tokenizer, config, - reasoning_backend) + regex_lp = ( + get_local_guided_decoding_logits_processor( + regex_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + if is_local + else await get_guided_decoding_logits_processor( + regex_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + ) assert regex_lp is not None tensor = torch.rand(151664) original_tensor = torch.clone(tensor) @@ -137,15 +142,17 @@ async def test_guided_logits_processor_with_reasoning( assert tensor.shape == original_tensor.shape assert torch.allclose(tensor, original_tensor) - token_ids = deepseek_r1_qwen_tokenizer.encode( - "<think>here is the thinking process") - json_request = GuidedDecodingParams(json=sample_json_schema, - backend=backend) - json_lp = get_local_guided_decoding_logits_processor( - json_request, deepseek_r1_qwen_tokenizer, config, - reasoning_backend) if is_local else \ - await get_guided_decoding_logits_processor( - json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) + token_ids = deepseek_r1_qwen_tokenizer.encode("<think>here is the thinking process") + json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) + json_lp = ( + get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + if is_local + else await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + ) assert json_lp is not None tensor = torch.rand(151664) original_tensor = torch.clone(tensor) @@ -155,14 +162,18 @@ async def test_guided_logits_processor_with_reasoning( # Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode( - "<think>here is the thinking process</think>") - json_request = GuidedDecodingParams(json=sample_json_schema, - backend=backend) - json_lp = get_local_guided_decoding_logits_processor( - json_request, deepseek_r1_qwen_tokenizer, config, - reasoning_backend) if is_local else \ - await get_guided_decoding_logits_processor( - json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) + "<think>here is the thinking process</think>" + ) + json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) + json_lp = ( + get_local_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + if is_local + else await get_guided_decoding_logits_processor( + json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend + ) + ) assert json_lp is not None tensor = torch.rand(151664) original_tensor = torch.clone(tensor) @@ -172,20 +183,16 @@ async def test_guided_logits_processor_with_reasoning( def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): - with pytest.raises(ValueError, - match="You can only use one kind of guided"): + with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, regex=sample_regex) - with pytest.raises(ValueError, - match="You can only use one kind of guided"): + with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, json_object=True) - with pytest.raises(ValueError, - match="You can only use one kind of guided"): + with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"]) - with pytest.raises(ValueError, - match="You can only use one kind of guided"): + with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") @@ -193,8 +200,7 @@ def test_guided_decoding_backend_options(): """Test backend-specific options""" with pytest.warns(DeprecationWarning): guided_decoding_params = GuidedDecodingParams( - backend= - "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties" + backend="xgrammar:no-fallback,disable-any-whitespace,no-additional-properties" ) assert guided_decoding_params.backend == "xgrammar" assert guided_decoding_params.disable_fallback @@ -208,12 +214,11 @@ def test_pickle_xgrammar_tokenizer_data(): except ImportError: pytest.skip("Could not import xgrammar to run test") - from vllm.model_executor.guided_decoding.xgrammar_decoding import ( - TokenizerData) + from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData + tokenizer_data = TokenizerData( - metadata= - '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}', - encoded_vocab=['!', '"', '#', '$', '%'], + metadata='{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}', + encoded_vocab=["!", '"', "#", "$", "%"], ) pickled = pickle.dumps(tokenizer_data) @@ -222,5 +227,6 @@ def test_pickle_xgrammar_tokenizer_data(): depickled: TokenizerData = pickle.loads(pickled) assert depickled is not None - assert json.loads( - depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value + assert ( + json.loads(depickled.metadata)["vocab_type"] == xgr.VocabType.BYTE_LEVEL.value + ) diff --git a/tests/model_executor/test_logits_processor.py b/tests/model_executor/test_logits_processor.py index
532ebba038d3..af09cc7b4207 100644 --- a/tests/model_executor/test_logits_processor.py +++ b/tests/model_executor/test_logits_processor.py @@ -15,38 +15,36 @@ class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): + def __init__(self, vocab_size: int, scale: float, fake_logits: torch.Tensor): super().__init__(vocab_size=vocab_size, scale=scale) self.fake_logits = fake_logits.clone() def forward(self, *args, **kwargs): - with patch( + with ( + patch( "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( + lambda x, y: x, + ), + patch( "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): + lambda *args, **kwargs: self.fake_logits, + ), + ): return super().forward(*args, **kwargs) def _prepare_test( - batch_size: int + batch_size: int, ) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) + fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -72,10 +70,12 @@ def pick_ith(token_ids, logits): request_id=f"test_{i}", is_prompt=True, seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), + sampling_params=SamplingParams( + temperature=0, logits_processors=[pick_ith] + ), block_tables={0: [1]}, - )) + ) + ) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( @@ -83,16 +83,15 @@ def pick_ith(token_ids, logits): seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + lm_head=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata + ) assert torch.isinf(logits_processor_output[:, 0]).all() fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) + torch.testing.assert_close( + logits_processor_output[:, 1], fake_logits[:, 1], rtol=1e-4, atol=0.0 + ) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 4bdb651e5170..43fdacc4b6e6 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -14,23 +14,26 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") REVISION = os.environ.get("REVISION", "main") -MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", - "intfloat/multilingual-e5-base") +MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", "intfloat/multilingual-e5-base") REVISION_ROBERTA = os.environ.get("REVISION", "main") -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + 
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_model_loading_with_params(vllm_runner): """ Test parameter weight loading with tp>1. """ - with vllm_runner(model_name=MODEL_NAME, - revision=REVISION, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME, + revision=REVISION, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.model.llm_engine.model_config model_tokenizer = vllm_model.model.llm_engine.tokenizer @@ -57,18 +60,22 @@ def check_model(model): assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_roberta_model_loading_with_params(vllm_runner): """ Test parameter weight loading with tp>1. """ - with vllm_runner(model_name=MODEL_NAME_ROBERTA, - revision=REVISION_ROBERTA, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME_ROBERTA, + revision=REVISION_ROBERTA, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.model.llm_engine.model_config model_tokenizer = vllm_model.model.llm_engine.tokenizer @@ -95,18 +102,20 @@ def check_model(model): assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_facebook_roberta_model_loading_with_params(vllm_runner): """ Test loading roberta-base model with no lm_head. 
""" model_name = "FacebookAI/roberta-base" - with vllm_runner(model_name=model_name, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=model_name, dtype="float16", max_model_len=MAX_MODEL_LEN + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_tokenizer = vllm_model.model.llm_engine.tokenizer assert model_tokenizer.tokenizer_id == model_name diff --git a/tests/model_executor/test_weight_utils.py b/tests/model_executor/test_weight_utils.py index df625b8d6004..6dc120ddbac9 100644 --- a/tests/model_executor/test_weight_utils.py +++ b/tests/model_executor/test_weight_utils.py @@ -9,23 +9,24 @@ from huggingface_hub.utils import LocalEntryNotFoundError from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, enable_hf_transfer) + download_weights_from_hf, + enable_hf_transfer, +) def test_hf_transfer_auto_activation(): if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: # in case it is already set, we can't test the auto activation - pytest.skip( - "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") + pytest.skip("HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") enable_hf_transfer() try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + HF_TRANSFER_ACTIVE = True except ImportError: HF_TRANSFER_ACTIVE = False - assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANSFER_ACTIVE) + assert huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == HF_TRANSFER_ACTIVE def test_download_weights_from_hf(): @@ -34,22 +35,30 @@ def test_download_weights_from_hf(): # if offline is set and model is not cached huggingface_hub.constants.HF_HUB_OFFLINE = True with pytest.raises(LocalEntryNotFoundError): - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # download the model huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # now it should work offline huggingface_hub.constants.HF_HUB_OFFLINE = True - assert download_weights_from_hf( - "facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) is not None + assert ( + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) + is not None + ) if __name__ == "__main__": diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py index b4c771840196..2c008bfb7507 100644 --- a/tests/models/language/generation/test_bart.py +++ b/tests/models/language/generation/test_bart.py @@ -7,8 +7,12 @@ from vllm.sequence import SampleLogprobs -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) +from ....conftest import ( + DecoderPromptType, + ExplicitEncoderDecoderPrompt, + HfRunner, + VllmRunner, +) from ....utils import multi_gpu_test from ...utils import check_logprobs_close @@ -40,7 +44,7 @@ def run_test( tensor_parallel_size: int, distributed_executor_backend: 
Optional[str] = None, ) -> None: - ''' + """ Test the vLLM BART model for a variety of encoder/decoder input prompts, by validating it against HuggingFace (HF) BART. @@ -48,7 +52,7 @@ def run_test( * hf_runner: HuggingFace (HF) test model runner * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a + * example_encoder_decoder_prompts: test fixture which provides a dictionary of dummy prompts * model: the HF ID of the specific BART variant under test * dtype: the tensor datatype to employ @@ -59,45 +63,45 @@ def run_test( prompt scenarios to test A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - + specifically when the decoder prompt is None. + The HF GenerationMixin's default behavior is to force the first decoded token to be <BOS> if the prompt does not already contain <BOS> (this is accomplished using a logit processor setting.) - + So when we use HF BART as our baseline for comparison, note that when the user provides a request with a None decoder prompt (i.e. a singleton encoder prompt, or else an explicit encoder/ decoder prompt with the decoder sub-prompt set to None), HF and vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, + + * HF will (1) tokenize the None prompt as an empty token-list, (2) append <decoder-start-token> to the beginning, yielding [<decoder-start-token>], (3) pass this token list to the model, and then (4) after computing logits during prefill, override the model logits & force <BOS> to be the first generated token. - + * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder- start-token to the beginning, yielding [<decoder-start-token><BOS>], (3) pass these tokens to the model & proceed with generation. - + The net effect is that compared to vLLM, the list of HF *decoded* tokens will contain one more initial <BOS> than the vLLM generated tokens, because vLLM's <BOS> token is injected into the prompt rather than into the generated output. This is in spite of the fact that overall, the complete sequences (prompt + decoded tokens) produced by vLLM will match HF. - + So when we use HF decoded token output to validate vLLM's decoded token output, the testing process must account for the difference in decoded token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. - + decoder-prompt-is-None case. + One option is to disable the logit processor feature that forces the <BOS> token to be decoded (forced_bos_token_id = None), eliminating the problem entirely. However this is not "normal" BART usage. - + The other option is - only in the decoder-prompt-is-None case - to discard the first decoded token from the HF output before comparing it to vLLM. @@ -105,7 +109,7 @@ def run_test( To that end, when testing the scenario where the decoder prompt is None (and only in that one scenario), this test skips the first HF decoded token during the process of validating the vLLM decoded output. - ''' + """ # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -122,13 +126,16 @@ def run_test( # decoder-only unit tests expect), so when testing an encoder/decoder # model we must explicitly specify enforce_eager=True in the VllmRunner # constructor.
- with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + ) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) + prompts, max_tokens, num_logprobs + ) # Configuration settings for HF baseline hf_kwargs = { @@ -139,20 +146,18 @@ def run_test( "length_penalty": 1.0, "early_stopping": False, "no_repeat_ngram_size": None, - "min_length": 0 + "min_length": 0, } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( prompts, max_tokens, num_logprobs, **hf_kwargs, - )) + ) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) + hf_skip_tokens = 1 if decoder_prompt_type == DecoderPromptType.NONE else 0 check_logprobs_close( outputs_0_lst=hf_outputs, @@ -169,8 +174,9 @@ def run_test( @pytest.mark.parametrize( "model", [ - pytest.param("facebook/bart-base", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "facebook/bart-base", marks=[pytest.mark.core_model, pytest.mark.cpu_model] + ), pytest.param("facebook/bart-large-cnn"), ], ) @@ -178,9 +184,16 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - +def test_models( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model, + dtype, + max_tokens, + num_logprobs, + decoder_prompt_type, +) -> None: run_test( hf_runner, vllm_runner, @@ -201,11 +214,17 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) -def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: +def test_models_distributed( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, + model, + dtype, + max_tokens, + num_logprobs, + decoder_prompt_type, +) -> None: run_test( hf_runner, vllm_runner, diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index ea240d227889..816137e2b8e7 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -62,8 +62,7 @@ pytest.param( "openbmb/MiniCPM3-4B", # fused_moe not supported on CPU - marks=[pytest.mark.core_model, - large_gpu_mark(min_gb=32)], + marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)], ), pytest.param( "facebook/opt-125m", # opt @@ -92,16 +91,24 @@ pytest.param( "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], - ) - ]) + ), + ], +) @pytest.mark.parametrize("max_tokens", [32]) 
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, - monkeypatch) -> None: - + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + num_logprobs: int, + use_rocm_aiter: bool, + monkeypatch, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") @@ -122,34 +129,37 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds - else None) + prompt_embeds: Optional[list[torch.Tensor]] = [] if use_prompt_embeds else None prompt_token_ids = [] for prompt in example_prompts: - token_ids = hf_model.tokenizer(prompt, - return_tensors="pt").input_ids.to( - hf_model.model.device) + token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids.to( + hf_model.model.device + ) prompt_token_ids.append(token_ids) if prompt_embeds is not None: - prompt_embeds.append(hf_model.model.get_input_embeddings()( - token_ids).squeeze(0)) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0) + ) with vllm_runner( - model, - tokenizer_name=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - trust_remote_code=model_info.trust_remote_code, - max_num_seqs=2, - enable_prompt_embeds=use_prompt_embeds, + model, + tokenizer_name=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + max_num_seqs=2, + enable_prompt_embeds=use_prompt_embeds, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) if prompt_embeds is not None: vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( - prompt_embeds, max_tokens, num_logprobs) + prompt_embeds, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 5be4ae874e61..85b6f29b151c 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -11,17 +11,17 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner( - model, - load_format="dummy", + model, + load_format="dummy", ) as llm: if model == "google/gemma-3-4b-it": normalizers = llm.model.collective_rpc( - lambda self: self.model_runner.model.language_model.model. 
- normalizer.cpu().item()) + lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() + ) config = llm.model.llm_engine.model_config.hf_config.text_config else: normalizers = llm.model.collective_rpc( - lambda self: self.model_runner.model.model.normalizer.cpu( - ).item()) + lambda self: self.model_runner.model.model.normalizer.cpu().item() + ) config = llm.model.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/language/generation/test_granite.py b/tests/models/language/generation/test_granite.py index 2a39f78a708e..e569e75ff3a8 100644 --- a/tests/models/language/generation/test_granite.py +++ b/tests/models/language/generation/test_granite.py @@ -26,11 +26,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index eba14e64553e..5ca37df49051 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -77,7 +77,6 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -88,13 +87,15 @@ def test_models( with hf_runner(model) as hf_model: if model not in HF_UNSUPPORTED_MODELS: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) else: hf_outputs = None with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) if model in V1_SUPPORTED_MODELS: with monkeypatch.context() as m: @@ -102,12 +103,15 @@ def test_models( if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - enforce_eager=True, - enable_prefix_caching=False) as vllm_model: + with vllm_runner( + model, + max_num_seqs=MAX_NUM_SEQS, + enforce_eager=True, + enable_prefix_caching=False, + ) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) else: vllm_v1_outputs = None @@ -139,7 +143,6 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -150,13 +153,14 @@ def test_batching( for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: - single_output, = vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs) + (single_output,) = vllm_model.generate_greedy_logprobs( + [prompt], max_tokens, num_logprobs + ) for_loop_outputs.append(single_output) batched_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, 
num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=for_loop_outputs, @@ -181,18 +185,22 @@ def test_chunked_prefill( max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, num_logprobs) + with vllm_runner( + model, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) as vllm_model: + chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: + with vllm_runner( + model, enable_chunked_prefill=False, max_num_seqs=max_num_seqs + ) as vllm_model: non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=chunked, @@ -211,8 +219,8 @@ def test_chunked_prefill_with_parallel_sampling( max_tokens: int, ) -> None: """ - Tests chunked prefill in conjunction with n > 1. - + Tests chunked prefill in conjunction with n > 1. + In this case, prefill is populated with decoding tokens and we test that it doesn't fail. @@ -220,16 +228,13 @@ def test_chunked_prefill_with_parallel_sampling( decoding steps inside a chunked prefill forward pass (where we have both prefill and decode together) """ - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) + sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens) with vllm_runner( - model, - enable_chunked_prefill=True, - # forces prefill chunks with decoding - max_num_batched_tokens=MAX_NUM_SEQS * 3, - max_num_seqs=MAX_NUM_SEQS, + model, + enable_chunked_prefill=True, + # forces prefill chunks with decoding + max_num_batched_tokens=MAX_NUM_SEQS * 3, + max_num_seqs=MAX_NUM_SEQS, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -247,10 +252,8 @@ def test_mamba_cache_cg_padding( batch size. If it's not, a torch RuntimeError will be raised because tensor dimensions aren't compatible. """ - vllm_config = EngineArgs(model=model, - trust_remote_code=True).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph( - len(example_prompts)): + vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() + while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): example_prompts.append(example_prompts[0]) try: @@ -260,7 +263,8 @@ def test_mamba_cache_cg_padding( pytest.fail( "Couldn't run batch size which is not equal to a Cuda Graph " "captured batch size. 
" - "Could be related to mamba cache not padded correctly") + "Could be related to mamba cache not padded correctly" + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -277,8 +281,7 @@ def test_models_preemption_recompute( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: scheduler = vllm_model.model.llm_engine.scheduler[0] scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + preempt_vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) scheduler.ENABLE_ARTIFICIAL_PREEMPT = False vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) @@ -310,8 +313,10 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") + pytest.fail( + "Hybrid inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily " + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -320,10 +325,10 @@ def test_state_cleanup( example_prompts, model: str, ) -> None: - """ + """ This test is for verifying that the Hybrid state is cleaned up between steps. - + If its not cleaned, an error would be expected. """ try: @@ -331,8 +336,10 @@ def test_state_cleanup( for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") + pytest.fail( + "Hybrid inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids" + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -343,15 +350,13 @@ def test_multistep_correctness( model: str, max_tokens: int, ) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_outputs_multistep = vllm_model.generate_greedy( - example_prompts, max_tokens) + with vllm_runner(model, num_scheduler_steps=8, max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, num_scheduler_steps=1, max_num_seqs=2) as vllm_model: vllm_outputs_single_step = vllm_model.generate_greedy( - example_prompts, max_tokens) + example_prompts, max_tokens + ) check_outputs_equal( outputs_0_lst=vllm_outputs_multistep, @@ -372,15 +377,15 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, tensor_parallel_size=1, max_num_seqs=2) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, tensor_parallel_size=2, max_num_seqs=2) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=vllm_outputs_tp_1, diff --git 
a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index c70698ede37a..c477789628b4 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -8,7 +8,9 @@ import pytest from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( - MistralToolCall, MistralToolParser) + MistralToolCall, + MistralToolParser, +) from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -35,136 +37,114 @@ ] # for function calling -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, }, -}, { - "type": "function", - "function": { - "name": "rewrite", - "description": "Rewrites text", - "parameters": { - "type": "object", - "required": [], - "properties": { - "text": { - "type": "string", - "description": "The input text to rewrite." - } - } - } - } -}] -MSGS = [ { - "role": "system", - "content": "You are an assistant." + "type": "function", + "function": { + "name": "rewrite", + "description": "Rewrites text", + "parameters": { + "type": "object", + "required": [], + "properties": { + "text": { + "type": "string", + "description": "The input text to rewrite.", + } + }, + }, + }, }, +] +MSGS = [ + {"role": "system", "content": "You are an assistant."}, { - "role": - "user", - "content": - "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa + "role": "user", + "content": "Could you please rewrite the below article? 
\n\n My English needs improvving, maybe I make errors.", # noqa }, { - "role": - "assistant", - "content": - "", - "tool_calls": [{ - "id": "bbc5b7ede", - "type": "function", - "function": { - "name": - "rewrite", - "arguments": - '{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "bbc5b7ede", + "type": "function", + "function": { + "name": "rewrite", + "arguments": '{"text":"My English needs improvving, maybe I make errors."}', # noqa + }, } - }] + ], }, { "role": "tool", - "content": - "{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa + "content": '{"action":"rewrite","outcome":"My English needs improving, maybe I make errors."}', # noqa "tool_call_id": "bbc5b7ede", - "name": "rewrite" + "name": "rewrite", }, { "role": "assistant", - "content": "---\n\nMy English needs improving, maybe I make errors" + "content": "---\n\nMy English needs improving, maybe I make errors", }, { - "role": - "user", - "content": ("Can you tell me what the temperate" - " will be in Dallas, in fahrenheit?") - } + "role": "user", + "content": ( + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" + ), + }, ] SAMPLE_JSON_SCHEMA = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + "minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -172,17 +152,25 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: # TODO(sang): Sliding window should be tested separately. 
with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: + with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -196,27 +184,35 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", ) as mistral_format_model: mistral_format_outputs = mistral_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="auto", - load_format="safetensors", - config_format="hf", + model, + dtype=dtype, + tokenizer_mode="auto", + load_format="safetensors", + config_format="hf", ) as hf_format_model: hf_format_outputs = hf_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_format_outputs, @@ -228,34 +224,35 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages(vllm_runner, model: str, - dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - max_model_len=8192, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: +def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None: + with vllm_runner( + model, + dtype=dtype, + max_model_len=8192, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: for prompt in SYMBOLIC_LANG_PROMPTS: msg = {"role": "user", "content": prompt} - outputs = vllm_model.model.chat([msg], - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.model.chat([msg], sampling_params=SAMPLING_PARAMS) assert "�" not in outputs[0].outputs[0].text.strip() @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: - + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: msgs = copy.deepcopy(MSGS) - outputs = vllm_model.model.chat(msgs, - tools=TOOLS, - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.model.chat( + msgs, tools=TOOLS, 
sampling_params=SAMPLING_PARAMS + ) tokenizer = vllm_model.model.get_tokenizer() tool_parser = MistralToolParser(tokenizer) @@ -267,16 +264,18 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: assert parsed_message.tools_called assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id) - assert parsed_message.tool_calls[ - 0].function.name == "get_current_weather" - assert parsed_message.tool_calls[ - 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa + assert parsed_message.tool_calls[0].function.name == "get_current_weather" + assert ( + parsed_message.tool_calls[0].function.arguments + == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' + ) # noqa assert parsed_message.content is None @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("guided_backend", - ["outlines", "lm-format-enforcer", "xgrammar"]) +@pytest.mark.parametrize( + "guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"] +) def test_mistral_guided_decoding( monkeypatch: pytest.MonkeyPatch, vllm_runner, @@ -288,26 +287,24 @@ def test_mistral_guided_decoding( m.setenv("VLLM_USE_V1", "0") with vllm_runner( - model, - dtype='bfloat16', - tokenizer_mode="mistral", - guided_decoding_backend=guided_backend, + model, + dtype="bfloat16", + tokenizer_mode="mistral", + guided_decoding_backend=guided_backend, ) as vllm_model: guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA) - params = SamplingParams(max_tokens=512, - temperature=0.7, - guided_decoding=guided_decoding) - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {SAMPLE_JSON_SCHEMA}" - }] + params = SamplingParams( + max_tokens=512, temperature=0.7, guided_decoding=guided_decoding + ) + + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}", + }, + ] outputs = vllm_model.model.chat(messages, sampling_params=params) generated_text = outputs[0].outputs[0].text @@ -315,8 +312,7 @@ def test_mistral_guided_decoding( assert outputs is not None try: - jsonschema.validate(instance=json_response, - schema=SAMPLE_JSON_SCHEMA) + jsonschema.validate(instance=json_response, schema=SAMPLE_JSON_SCHEMA) except jsonschema.exceptions.ValidationError: pytest.fail("Generated response is not valid with JSON schema") @@ -346,17 +342,10 @@ def get_vocab(): "city": "Dallas", "state": "TX", "unit": "fahrenheit", - "sub_dict": { - "foo": "bar", - "inner": { - "x": 1, - "y": 2 - } - }, + "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}}, } - model_output = ( - f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}") + model_output = f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}" parsed = parser.extract_tool_calls(model_output, None) diff --git a/tests/models/language/generation/test_phimoe.py b/tests/models/language/generation/test_phimoe.py index 6c9cc2821c30..e640655784cc 100644 --- a/tests/models/language/generation/test_phimoe.py +++ b/tests/models/language/generation/test_phimoe.py @@ -15,62 +15,56 @@ def test_phimoe_routing_function(): from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { 0: { - "hidden_states": - torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - 
"gating_output": - torch.tensor([0.1, 0.2, 0.3, 0.4], - dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.1, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, }, 1: { - "hidden_states": - torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - "gating_output": - torch.tensor([0.4, 0.2, 0.3, 0.4], - dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, - } + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.4, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, + }, } ground_truth = { 0: { - "topk_weights": - torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + "topk_weights": torch.tensor( + [1.0, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([3, 2], dtype=torch.long, requires_grad=False), }, 1: { - "topk_weights": - torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([0, 3], dtype=torch.long, requires_grad=False), - } + "topk_weights": torch.tensor( + [0.5, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + }, } for test_id in test_case: topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) - assert torch.allclose(topk_weights, - ground_truth[test_id]["topk_weights"]) + assert torch.allclose(topk_weights, ground_truth[test_id]["topk_weights"]) assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="This test takes a lot time to run on CPU, " - "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif( + condition=current_platform.is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.", +) @large_gpu_test(min_gb=80) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -87,11 +81,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index a663679a9c7c..66cdb5dd2a20 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -6,8 +6,7 @@ import pytest from tests.conftest import HfRunner -from tests.models.utils import (EmbedModelInfo, check_embeddings_close, - matryoshka_fy) +from tests.models.utils import EmbedModelInfo, check_embeddings_close, matryoshka_fy def run_embedding_correctness_test( @@ -29,12 +28,14 @@ def run_embedding_correctness_test( ) -def 
correctness_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - example_prompts, - vllm_extra_kwargs=None, - hf_model_callback=None): +def correctness_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + example_prompts, + vllm_extra_kwargs=None, + hf_model_callback=None, +): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -51,18 +52,16 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype - with vllm_runner(model_info.name, - task="embed", - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: + with vllm_runner( + model_info.name, task="embed", max_model_len=None, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( - model_info.name, - dtype="float32", - is_sentence_transformer=True, + model_info.name, + dtype="float32", + is_sentence_transformer=True, ) as hf_model: - if hf_model_callback is not None: hf_model_callback(hf_model) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 6c4fde5fdfa9..65c82f9a03c4 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -27,7 +27,6 @@ class VllmMtebEncoder(mteb.Encoder): - def __init__(self, vllm_model): super().__init__() self.model = vllm_model @@ -50,8 +49,7 @@ def encode( def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -61,17 +59,15 @@ def predict( queries = [s[0] for s in sentences] corpus = [s[1] for s in sentences] - outputs = self.model.score(queries, - corpus, - truncate_prompt_tokens=-1, - use_tqdm=False) + outputs = self.model.score( + queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False + ) scores = np.array(outputs) scores = scores[np.argsort(r)] return scores class OpenAIClientMtebEncoder(mteb.Encoder): - def __init__(self, model_name: str, client): super().__init__() self.model_name = model_name @@ -84,8 +80,9 @@ def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: r = self.rng.permutation(len(sentences)) sentences = [sentences[i] for i in r] - embeddings = self.client.embeddings.create(model=self.model_name, - input=sentences) + embeddings = self.client.embeddings.create( + model=self.model_name, input=sentences + ) outputs = [d.embedding for d in embeddings.data] embeds = np.array(outputs) embeds = embeds[np.argsort(r)] @@ -93,7 +90,6 @@ def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: class ScoreClientMtebEncoder(mteb.Encoder): - def __init__(self, model_name: str, url): super().__init__() self.model_name = model_name @@ -102,8 +98,7 @@ def __init__(self, model_name: str, url): def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -119,27 +114,30 @@ def predict( return scores def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "text_1": query, - "text_2": corpus, - "truncate_prompt_tokens": -1, - }).json() - return response['data'][0]["score"] + response = requests.post( + self.url, + json={ + "model": self.model_name, + 
"text_1": query, + "text_2": corpus, + "truncate_prompt_tokens": -1, + }, + ).json() + return response["data"][0]["score"] class RerankClientMtebEncoder(ScoreClientMtebEncoder): - def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "query": query, - "documents": [corpus], - "truncate_prompt_tokens": -1, - }).json() - return response['results'][0]["relevance_score"] + response = requests.post( + self.url, + json={ + "model": self.model_name, + "query": query, + "documents": [corpus], + "truncate_prompt_tokens": -1, + }, + ).json() + return response["results"][0]["relevance_score"] def run_mteb_embed_task(encoder, tasks): @@ -158,11 +156,13 @@ def run_mteb_embed_task(encoder, tasks): return main_score -def mteb_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None): +def mteb_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, +): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -171,23 +171,23 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype - with vllm_runner(model_info.name, - task="embed", - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: - + with vllm_runner( + model_info.name, task="embed", max_model_len=None, **vllm_extra_kwargs + ) as vllm_model: if model_info.architecture: - assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) + assert ( + model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures + ) - vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), - MTEB_EMBED_TASKS) + vllm_main_score = run_mteb_embed_task( + VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS + ) vllm_dtype = vllm_model.model.llm_engine.model_config.dtype - with hf_runner(model_info.name, - is_sentence_transformer=True, - dtype="float32") as hf_model: - + with hf_runner( + model_info.name, is_sentence_transformer=True, dtype="float32" + ) as hf_model: if hf_model_callback is not None: hf_model_callback(hf_model) @@ -226,8 +226,7 @@ def run_mteb_rerank(cross_encoder, tasks, languages): top_k=10, save_predictions=True, output_folder=f"{results_folder}/stage2", - previous_results= - f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", + previous_results=f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", encode_kwargs={"show_progress_bar": False}, ) main_score = results[0].scores["test"][0]["main_score"] @@ -235,14 +234,11 @@ def run_mteb_rerank(cross_encoder, tasks, languages): def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): - with hf_runner(model_name, is_cross_encoder=True, - dtype="float32") as hf_model: - + with hf_runner(model_name, is_cross_encoder=True, dtype="float32") as hf_model: original_predict = hf_model.predict def _predict( - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ): @@ -256,20 +252,22 @@ def _predict( if hf_model_callback is not None: hf_model_callback(hf_model) - st_main_score = run_mteb_rerank(hf_model, - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) + st_main_score = run_mteb_rerank( + hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS + ) 
st_dtype = next(hf_model.model.model.parameters()).dtype return st_main_score, st_dtype -def mteb_test_rerank_models(hf_runner, - vllm_runner, - model_info: RerankModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None, - vllm_mteb_encoder=VllmMtebEncoder, - atol=MTEB_RERANK_TOL): +def mteb_test_rerank_models( + hf_runner, + vllm_runner, + model_info: RerankModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, + vllm_mteb_encoder=VllmMtebEncoder, + atol=MTEB_RERANK_TOL, +): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -278,25 +276,29 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype - with vllm_runner(model_info.name, - task="score", - max_model_len=None, - max_num_seqs=8, - **vllm_extra_kwargs) as vllm_model: - + with vllm_runner( + model_info.name, + task="score", + max_model_len=None, + max_num_seqs=8, + **vllm_extra_kwargs, + ) as vllm_model: model_config = vllm_model.model.llm_engine.model_config if model_info.architecture: - assert (model_info.architecture in model_config.architectures) + assert model_info.architecture in model_config.architectures assert model_config.hf_config.num_labels == 1 - vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank( + vllm_mteb_encoder(vllm_model), + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS, + ) vllm_dtype = model_config.dtype st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, hf_model_callback) + hf_runner, model_info.name, hf_model_callback + ) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 64a8f25220da..9859a8b197db 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -8,85 +8,75 @@ MODELS = [ ########## BertModel - EmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), + EmbedModelInfo("BAAI/bge-base-en", architecture="BertModel", enable_test=True), + EmbedModelInfo("BAAI/bge-base-zh", architecture="BertModel", enable_test=False), + EmbedModelInfo("BAAI/bge-small-en", architecture="BertModel", 
enable_test=False), + EmbedModelInfo("BAAI/bge-small-zh", architecture="BertModel", enable_test=False), + EmbedModelInfo("BAAI/bge-large-en", architecture="BertModel", enable_test=False), + EmbedModelInfo("BAAI/bge-large-zh", architecture="BertModel", enable_test=False), + EmbedModelInfo( + "BAAI/bge-large-zh-noinstruct", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-base-en-v1.5", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-base-zh-v1.5", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-small-en-v1.5", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-small-zh-v1.5", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-large-en-v1.5", architecture="BertModel", enable_test=False + ), + EmbedModelInfo( + "BAAI/bge-large-zh-v1.5", architecture="BertModel", enable_test=False + ), ########## XLMRobertaModel - EmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - enable_test=True), + EmbedModelInfo("BAAI/bge-m3", architecture="XLMRobertaModel", enable_test=True), ########## Qwen2Model - EmbedModelInfo("BAAI/bge-code-v1", - architecture="Qwen2Model", - dtype="float32", - enable_test=True), + EmbedModelInfo( + "BAAI/bge-code-v1", architecture="Qwen2Model", dtype="float32", enable_test=True + ), ] RERANK_MODELS = [ ########## XLMRobertaForSequenceClassification - RerankModelInfo("BAAI/bge-reranker-base", - architecture="XLMRobertaForSequenceClassification", - enable_test=True), - RerankModelInfo("BAAI/bge-reranker-large", - architecture="XLMRobertaForSequenceClassification", - enable_test=False), - RerankModelInfo("BAAI/bge-reranker-v2-m3", - architecture="XLMRobertaForSequenceClassification", - enable_test=False) + RerankModelInfo( + "BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + enable_test=True, + ), + RerankModelInfo( + "BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False, + ), + RerankModelInfo( + "BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + enable_test=False, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 7fa9485dbc7f..972eb88d5d3e 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -8,45 +8,40 @@ 
from tests.conftest import HfRunner -from .mteb_utils import (RerankModelInfo, VllmMtebEncoder, - mteb_test_rerank_models) +from .mteb_utils import RerankModelInfo, VllmMtebEncoder, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + RerankModelInfo( + "BAAI/bge-reranker-v2-gemma", architecture="GemmaForSequenceClassification" + ), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 class GemmaRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") @torch.no_grad() - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def get_inputs(pairs, tokenizer, prompt=None): if prompt is None: prompt = PROMPT sep = "\n" - prompt_inputs = tokenizer(prompt, - return_tensors=None, - add_special_tokens=False)["input_ids"] - sep_inputs = tokenizer(sep, - return_tensors=None, - add_special_tokens=False)["input_ids"] + prompt_inputs = tokenizer( + prompt, return_tensors=None, add_special_tokens=False + )["input_ids"] + sep_inputs = tokenizer(sep, return_tensors=None, add_special_tokens=False)[ + "input_ids" + ] inputs = [] for query, passage in pairs: query_inputs = tokenizer( @@ -70,8 +65,7 @@ def get_inputs(pairs, tokenizer, prompt=None): return_token_type_ids=False, add_special_tokens=False, ) - item["input_ids"] = item[ - "input_ids"] + sep_inputs + prompt_inputs + item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs item["attention_mask"] = [1] * len(item["input_ids"]) inputs.append(item) return tokenizer.pad( @@ -87,14 +81,19 @@ def get_inputs(pairs, tokenizer, prompt=None): inputs = inputs.to(self.model.device) _n_tokens = inputs["input_ids"].shape[1] logits = self.model(**inputs, return_dict=True).logits - _scores = (logits[:, -1, - self.yes_loc].view(-1, ).float().sigmoid()) + _scores = ( + logits[:, -1, self.yes_loc] + .view( + -1, + ) + .float() + .sigmoid() + ) scores.append(_scores[0].item()) return torch.Tensor(scores) class GemmaMtebEncoder(VllmMtebEncoder): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.prompt = PROMPT @@ -103,12 +102,10 @@ def __init__(self, *args, **kwargs): def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: - _sentences = [] for query, corpus, prompt in sentences: query = self.query_template.format(query=query) @@ -119,8 +116,9 @@ def predict( @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: +def test_rerank_models_mteb( + vllm_runner, model_info: RerankModelInfo, monkeypatch +) -> None: monkeypatch.setenv("VLLM_USE_V1", "0") assert 
model_info.architecture == "GemmaForSequenceClassification" @@ -133,8 +131,10 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, } } - mteb_test_rerank_models(GemmaRerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - vllm_mteb_encoder=GemmaMtebEncoder) + mteb_test_rerank_models( + GemmaRerankerHfRunner, + vllm_runner, + model_info, + vllm_extra_kwargs, + vllm_mteb_encoder=GemmaMtebEncoder, + ) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 77df6d16a367..23cb01356938 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -18,12 +18,13 @@ @pytest.mark.parametrize( "model", [ - pytest.param("jason9693/Qwen2.5-1.5B-apeach", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "jason9693/Qwen2.5-1.5B-apeach", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) -@pytest.mark.parametrize("dtype", - ["half"] if current_platform.is_rocm() else ["float"]) +@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"]) def test_models( hf_runner, vllm_runner, @@ -40,9 +41,9 @@ def test_models( with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) # check logits difference @@ -53,5 +54,6 @@ def test_models( # the tolerance value of 1e-2 is selected based on the # half datatype tests in # tests/models/language/pooling/test_embedding.py - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py index 9a33063d7b46..c47c9c903b2a 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling/test_cross_encoder.py @@ -5,14 +5,19 @@ from .mteb_utils import RerankModelInfo, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - architecture="BertForSequenceClassification"), - RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", - architecture="Qwen3ForSequenceClassification") + RerankModelInfo( + "cross-encoder/ms-marco-TinyBERT-L-2-v2", + architecture="BertForSequenceClassification", + ), + RerankModelInfo( + "tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + architecture="Qwen3ForSequenceClassification", + ), ] @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index cc9e4102d5b7..cbd0f2d4efbd 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -26,35 +26,40 @@ def v1(run_with_both_engines): # case won't pass because gte-Qwen2-1.5B-instruct will cache custom # model code with bidirectional attention. 
# [Decoder-only] - pytest.param("BAAI/bge-multilingual-gemma2", - marks=[pytest.mark.core_model]), + pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), pytest.param( "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window - marks=[pytest.mark.core_model]), + marks=[pytest.mark.core_model], + ), # the qwen models interfere with each other (see PR # https://github.com/vllm-project/vllm/pull/18720). # To avoid this problem, for now we skip v0 since it will be # deprecated anyway. - pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", - marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), + pytest.param( + "ssmits/Qwen2-7B-Instruct-embed-base", + marks=[pytest.mark.skip_v0, pytest.mark.cpu_model], + ), # [Encoder-only] pytest.param( "BAAI/bge-base-en-v1.5", marks=[ # CPU only supports V1 pytest.mark.core_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), + pytest.mark.skip_v1, + ], + ), + pytest.param( + "sentence-transformers/all-MiniLM-L12-v2", marks=[pytest.mark.skip_v1] + ), + pytest.param("intfloat/multilingual-e5-small", marks=[pytest.mark.skip_v1]), + pytest.param( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", marks=[pytest.mark.skip_v1] + ), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param( + "sentence-transformers/stsb-roberta-base-v2", marks=[pytest.mark.skip_v1] + ), ], ) def test_models( @@ -71,13 +76,14 @@ def test_models( vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": - vllm_extra_kwargs["override_pooler_config"] = \ - PoolerConfig(pooling_type="MEAN", normalize=False) + vllm_extra_kwargs["override_pooler_config"] = PoolerConfig( + pooling_type="MEAN", normalize=False + ) max_model_len: Optional[int] = 512 if model in [ - "sentence-transformers/all-MiniLM-L12-v2", - "sentence-transformers/stsb-roberta-base-v2" + "sentence-transformers/all-MiniLM-L12-v2", + "sentence-transformers/stsb-roberta-base-v2", ]: max_model_len = None @@ -92,10 +98,9 @@ def test_models( with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, - task="embed", - max_model_len=max_model_len, - **vllm_extra_kwargs) as vllm_model: + with vllm_runner( + model, task="embed", max_model_len=max_model_len, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index c2f70bb647a4..26a680b81325 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -15,8 +15,9 @@ from ....utils import RemoteOpenAIServer # GritLM embedding implementation is only supported by XFormers backend. 
-pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"), - reason="GritLM requires XFormers") +pytestmark = pytest.mark.skipif( + not importlib.util.find_spec("xformers"), reason="GritLM requires XFormers" +) MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 @@ -76,8 +77,9 @@ async def run_client_embeddings( def gritlm_instruction(instruction): - return ("<|user|>\n" + instruction + - "\n<|embed|>\n" if instruction else "<|embed|>\n") + return ( + "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" + ) def get_test_data(): @@ -86,7 +88,8 @@ def get_test_data(): README.md in https://github.com/ContextualAI/gritlm """ q_instruction = gritlm_instruction( - "Given a scientific paper title, retrieve the paper's abstract", ) + "Given a scientific paper title, retrieve the paper's abstract", + ) queries = [ "Bitcoin: A Peer-to-Peer Electronic Cash System", "Generative Representational Instruction Tuning", @@ -120,9 +123,9 @@ def test_gritlm_offline_embedding(vllm_runner): queries, q_instruction, documents, d_instruction = get_test_data() with vllm_runner( - MODEL_NAME, - task="embed", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + task="embed", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.model @@ -167,9 +170,9 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" with vllm_runner( - MODEL_NAME, - task="generate", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + task="generate", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.model diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 0ad54785308e..58cf44dda226 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -9,61 +9,65 @@ MODELS = [ ########## BertModel - EmbedModelInfo("thenlper/gte-large", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), + EmbedModelInfo("thenlper/gte-large", architecture="BertModel", enable_test=True), + EmbedModelInfo("thenlper/gte-base", architecture="BertModel", enable_test=False), + EmbedModelInfo("thenlper/gte-small", architecture="BertModel", enable_test=False), + EmbedModelInfo( + "thenlper/gte-large-zh", architecture="BertModel", enable_test=False + ), + EmbedModelInfo("thenlper/gte-base-zh", architecture="BertModel", enable_test=False), + EmbedModelInfo( + "thenlper/gte-small-zh", architecture="BertModel", enable_test=False + ), ########### NewModel - EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - enable_test=True), + EmbedModelInfo( + "Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True, + ), + EmbedModelInfo( + "Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", enable_test=True + ), + EmbedModelInfo( + 
"Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", enable_test=True + ), ########### Qwen2ForCausalLM - EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - architecture="Qwen2ForCausalLM", - enable_test=True), + EmbedModelInfo( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=True, + ), ########## ModernBertModel - EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - architecture="ModernBertModel", - enable_test=True), + EmbedModelInfo( + "Alibaba-NLP/gte-modernbert-base", + architecture="ModernBertModel", + enable_test=True, + ), ########## Qwen3ForCausalLM - EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=True), - EmbedModelInfo("Qwen/Qwen3-Embedding-4B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=False), + EmbedModelInfo( + "Qwen/Qwen3-Embedding-0.6B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True, + ), + EmbedModelInfo( + "Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False, + ), ] V1FlashAttentionImpNotSupported = [ - "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base" + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + "Alibaba-NLP/gte-modernbert-base", ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, - monkeypatch) -> None: +def test_embed_models_mteb( + hf_runner, vllm_runner, model_info: EmbedModelInfo, monkeypatch +) -> None: if model_info.name in V1FlashAttentionImpNotSupported: monkeypatch.setenv("VLLM_USE_V1", "0") @@ -71,14 +75,13 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info, vllm_extra_kwargs) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, example_prompts, - monkeypatch) -> None: +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts, monkeypatch +) -> None: if model_info.name in V1FlashAttentionImpNotSupported: monkeypatch.setenv("VLLM_USE_V1", "0") @@ -86,5 +89,6 @@ def test_embed_models_correctness(hf_runner, vllm_runner, if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + correctness_test_embed_models( + hf_runner, vllm_runner, model_info, example_prompts, vllm_extra_kwargs + ) diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py index d899aaada262..ab135c4540b7 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling/test_intfloat.py @@ -8,40 +8,38 @@ MODELS = [ ########## BertModel - EmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), + EmbedModelInfo("intfloat/e5-small", 
architecture="BertModel", enable_test=True), + EmbedModelInfo("intfloat/e5-base", architecture="BertModel", enable_test=False), + EmbedModelInfo("intfloat/e5-large", architecture="BertModel", enable_test=False), + EmbedModelInfo( + "intfloat/multilingual-e5-small", architecture="BertModel", enable_test=False + ), ########## XLMRobertaModel - EmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), + EmbedModelInfo( + "intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + enable_test=True, + ), + EmbedModelInfo( + "intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False, + ), + EmbedModelInfo( + "intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 9bfe7411e16b..9d6b21b1a3b8 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -7,53 +7,57 @@ from vllm import PoolingParams from ...utils import EmbedModelInfo, RerankModelInfo -from .embed_utils import (check_embeddings_close, - correctness_test_embed_models, matryoshka_fy) +from .embed_utils import ( + check_embeddings_close, + correctness_test_embed_models, + matryoshka_fy, +) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ - EmbedModelInfo("jinaai/jina-embeddings-v3", - architecture="XLMRobertaModel", - is_matryoshka=True) + EmbedModelInfo( + "jinaai/jina-embeddings-v3", architecture="XLMRobertaModel", is_matryoshka=True + ) ] RERANK_MODELS = [ - RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification") + RerankModelInfo( + "jinaai/jina-reranker-v2-base-multilingual", + architecture="XLMRobertaForSequenceClassification", + ) ] @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: def hf_model_callback(model): model.encode = partial(model.encode, task="text-matching") - mteb_test_embed_models(hf_runner, - vllm_runner, - model_info, - hf_model_callback=hf_model_callback) + mteb_test_embed_models( + hf_runner, vllm_runner, model_info, hf_model_callback=hf_model_callback + ) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - 
example_prompts) -> None: - +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: def hf_model_callback(model): model.encode = partial(model.encode, task="text-matching") - correctness_test_embed_models(hf_runner, - vllm_runner, - model_info, - example_prompts, - hf_model_callback=hf_model_callback) + correctness_test_embed_models( + hf_runner, + vllm_runner, + model_info, + example_prompts, + hf_model_callback=hf_model_callback, + ) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) @@ -76,32 +80,32 @@ def test_matryoshka( example_prompts = [str(s).strip() for s in example_prompts] with hf_runner( - model_info.name, - dtype=dtype, - is_sentence_transformer=True, + model_info.name, + dtype=dtype, + is_sentence_transformer=True, ) as hf_model: hf_outputs = hf_model.encode(example_prompts, task="text-matching") hf_outputs = matryoshka_fy(hf_outputs, dimensions) - with vllm_runner(model_info.name, - task="embed", - dtype=dtype, - max_model_len=None) as vllm_model: + with vllm_runner( + model_info.name, task="embed", dtype=dtype, max_model_len=None + ) as vllm_model: assert vllm_model.model.llm_engine.model_config.is_matryoshka matryoshka_dimensions = ( - vllm_model.model.llm_engine.model_config.matryoshka_dimensions) + vllm_model.model.llm_engine.model_config.matryoshka_dimensions + ) assert matryoshka_dimensions is not None if dimensions not in matryoshka_dimensions: with pytest.raises(ValueError): vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) else: vllm_outputs = vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index e74c58744dd2..6bd848699b21 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -10,43 +10,42 @@ from .mteb_utils import RerankModelInfo, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=True), - RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=False) + RerankModelInfo( + "mixedbread-ai/mxbai-rerank-base-v2", + architecture="Qwen2ForSequenceClassification", + enable_test=True, + ), + RerankModelInfo( + "mixedbread-ai/mxbai-rerank-large-v2", + architecture="Qwen2ForSequenceClassification", + enable_test=False, + ), ] class MxbaiRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = 
AutoTokenizer.from_pretrained(model_name, padding_side="left") self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") self.no_loc = self.tokenizer.convert_tokens_to_ids("0") - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") for key in inputs: inputs[key] = inputs[key].to(self.model.device) return inputs @@ -78,5 +77,6 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: "method": "from_2_way_softmax", } - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models( + MxbaiRerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs + ) diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py index e16ec239a338..e5840b77e606 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling/test_nomic.py @@ -7,30 +7,32 @@ from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - enable_test=True) + EmbedModelInfo( + "nomic-ai/nomic-embed-text-v1", architecture="NomicBertModel", enable_test=True + ), + EmbedModelInfo( + "nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False, + ), + EmbedModelInfo( + "nomic-ai/CodeRankEmbed", architecture="NomicBertModel", enable_test=False + ), + EmbedModelInfo( + "nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + enable_test=True, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_nomic_max_model_len.py b/tests/models/language/pooling/test_nomic_max_model_len.py index 250b3a52835a..ce348785ec15 100644 --- a/tests/models/language/pooling/test_nomic_max_model_len.py +++ b/tests/models/language/pooling/test_nomic_max_model_len.py @@ -7,10 +7,10 @@ MODELS = [ 
EmbedModelInfo("nomic-ai/nomic-embed-text-v1"), - #EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), - #EmbedModelInfo("nomic-ai/CodeRankEmbed"), + # EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), + # EmbedModelInfo("nomic-ai/CodeRankEmbed"), EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe"), - #EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), + # EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), ] rope_theta = 1000 @@ -21,23 +21,20 @@ @pytest.mark.parametrize("model_info", MODELS) def test_default(model_info, vllm_runner): - with vllm_runner(model_info.name, task="embed", - max_model_len=None) as vllm_model: + with vllm_runner(model_info.name, task="embed", max_model_len=None) as vllm_model: model_config = vllm_model.model.llm_engine.model_config if model_info.name == "nomic-ai/nomic-embed-text-v2-moe": # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. assert model_config.max_model_len == 512 else: - assert ( - model_config.max_model_len == original_max_position_embeddings) + assert model_config.max_model_len == original_max_position_embeddings @pytest.mark.parametrize("model_info", MODELS) def test_set_max_model_len_legal(model_info, vllm_runner): # set max_model_len <= 512 - with vllm_runner(model_info.name, task="embed", - max_model_len=256) as vllm_model: + with vllm_runner(model_info.name, task="embed", max_model_len=256) as vllm_model: model_config = vllm_model.model.llm_engine.model_config assert model_config.max_model_len == 256 @@ -46,12 +43,12 @@ def test_set_max_model_len_legal(model_info, vllm_runner): # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. with pytest.raises(ValueError): - with vllm_runner(model_info.name, task="embed", - max_model_len=1024): + with vllm_runner(model_info.name, task="embed", max_model_len=1024): pass else: - with vllm_runner(model_info.name, task="embed", - max_model_len=1024) as vllm_model: + with vllm_runner( + model_info.name, task="embed", max_model_len=1024 + ) as vllm_model: model_config = vllm_model.model.llm_engine.model_config assert model_config.max_model_len == 1024 @@ -66,10 +63,9 @@ def test_set_max_model_len_illegal(model_info, vllm_runner): # set max_model_len > 2048 by hf_overrides hf_overrides = {"max_model_len": 4096} with pytest.raises(ValueError): - with vllm_runner(model_info.name, - task="embed", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, task="embed", max_model_len=None, hf_overrides=hf_overrides + ): pass @@ -80,16 +76,14 @@ def test_use_rope_scaling_legal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + "original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + "max_model_len": max_model_len, } - with vllm_runner(model_info.name, - task="embed", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, task="embed", max_model_len=None, hf_overrides=hf_overrides + ): pass @@ -100,16 +94,17 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings - } + "original_max_position_embeddings": original_max_position_embeddings, + }, } # illegal max_model_len with pytest.raises(ValueError): - with vllm_runner(model_info.name, - task="embed", - max_model_len=max_model_len 
+ 1, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + task="embed", + max_model_len=max_model_len + 1, + hf_overrides=hf_overrides, + ): pass hf_overrides = { @@ -117,15 +112,13 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + "original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + 1 + "max_model_len": max_model_len + 1, } # illegal max_model_len by hf_overrides with pytest.raises(ValueError): - with vllm_runner(model_info.name, - task="embed", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, task="embed", max_model_len=None, hf_overrides=hf_overrides + ): pass diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 9c6a833b4138..36ef11b9b043 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -11,43 +11,42 @@ from .mteb_utils import RerankModelInfo, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("Qwen/Qwen3-Reranker-0.6B", - architecture="Qwen3ForSequenceClassification", - enable_test=True), - RerankModelInfo("Qwen/Qwen3-Reranker-4B", - architecture="Qwen3ForSequenceClassification", - enable_test=False) + RerankModelInfo( + "Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + enable_test=True, + ), + RerankModelInfo( + "Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + enable_test=False, + ), ] class Qwen3RerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") for key in inputs: inputs[key] = inputs[key].to(self.model.device) return inputs @@ -72,7 +71,6 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { @@ -86,15 +84,14 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: 
if model_info.name == "Qwen/Qwen3-Reranker-4B": vllm_extra_kwargs["max_num_seqs"] = 1 - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models( + Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs + ) @pytest.mark.parametrize("model_info", RERANK_MODELS) @multi_gpu_test(num_gpus=2) -def test_rerank_models_mteb_tp(vllm_runner, - model_info: RerankModelInfo) -> None: - +def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None: assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { @@ -109,8 +106,6 @@ def test_rerank_models_mteb_tp(vllm_runner, if model_info.name == "Qwen/Qwen3-Reranker-4B": vllm_extra_kwargs["max_num_seqs"] = 1 - mteb_test_rerank_models(Qwen3RerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - atol=1.2e-2) + mteb_test_rerank_models( + Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs, atol=1.2e-2 + ) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 3b7fab3ba5c9..9eeac29f9ab6 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -24,10 +24,8 @@ def v1(run_with_both_engines): def math_step_prompts(): # ruff: noqa: E501 data = { - "system": - "Please reason step by step, and put your final answer within \\boxed{}. ", - "query": - "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", + "system": "Please reason step by step, and put your final answer within \\boxed{}. ", + "query": "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", "response": [ "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.", "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.", @@ -35,16 +33,16 @@ def math_step_prompts(): "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). 
Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).", ], } - answer = "".join(data['response']) + "" + answer = "".join(data["response"]) + "" prompt = f"system\n{data['system']}\nuser\n{data['query']}\nassistant\n{answer}<|endoftext|>" return [prompt] def step_reward_patch_hf_model(hf_model: HfRunner): - # Patch the hf_runner to use the step reward function - def make_step_rewards(logits: torch.Tensor, - token_masks: torch.Tensor) -> list[list[float]]: + def make_step_rewards( + logits: torch.Tensor, token_masks: torch.Tensor + ) -> list[list[float]]: probabilities = F.softmax(logits, dim=-1) probabilities = probabilities * token_masks.unsqueeze(-1) @@ -62,7 +60,7 @@ def reward(prompts: list[str]) -> list[list[float]]: outputs = hf_model.model(input_ids=input_ids) step_sep_id = hf_model.tokenizer.encode("")[0] - token_masks = (input_ids == step_sep_id) + token_masks = input_ids == step_sep_id return make_step_rewards(outputs[0], token_masks) hf_model.reward = reward # type: ignore[attr-defined] @@ -73,8 +71,10 @@ def reward(prompts: list[str]) -> list[list[float]]: @pytest.mark.parametrize( "model", [ - pytest.param("Qwen/Qwen2.5-Math-PRM-7B", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "Qwen/Qwen2.5-Math-PRM-7B", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) @pytest.mark.parametrize("dtype", ["half"]) diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index c75ff1445616..d220f0ec2597 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -37,8 +37,9 @@ def test_cross_encoder_1_to_1(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict([text_pair]).tolist() - with vllm_runner(model_name, task="score", dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, task="score", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -56,8 +57,9 @@ def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, task="score", dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, task="score", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -76,8 +78,9 @@ def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, task="score", dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, task="score", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 @@ -95,17 +98,15 @@ def emb_model_name(request): def test_embedding_1_to_1(vllm_runner, hf_runner, emb_model_name): text_pair = [TEXTS_1[0], TEXTS_2[0]] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: hf_embeddings = 
hf_model.encode(text_pair) - hf_outputs = [ - F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0) - ] + hf_outputs = [F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0)] - with vllm_runner(emb_model_name, - task="embed", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, task="embed", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -120,20 +121,18 @@ def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[0], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - task="embed", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, task="embed", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -149,20 +148,18 @@ def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[1], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - task="embed", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, task="embed", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index d6b5dbd08372..5174a481b139 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -7,50 +7,64 @@ from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - 
is_matryoshka=True, - architecture="GteModel", - enable_test=True), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + enable_test=True, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + enable_test=True, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + enable_test=True, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + enable_test=True, + ), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + enable_test=True, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index 33aff1c873fc..48f92d61ce4e 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -20,51 +20,57 @@ field.""" -def test_smaller_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_smaller_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = 10 - with vllm_runner(model_name, task="embed", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, task="embed", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.model.encode( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == truncate_prompt_tokens -def test_max_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): +def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input_str): truncate_prompt_tokens = -1 - with vllm_runner(model_name, task="embed", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, task="embed", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.model.encode( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == 
max_model_len -def test_bigger_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_bigger_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = max_model_len + 1 - with pytest.raises(ValueError), vllm_runner( - model_name, task="embed", - max_model_len=max_model_len) as vllm_model: - + with ( + pytest.raises(ValueError), + vllm_runner( + model_name, task="embed", max_model_len=max_model_len + ) as vllm_model, + ): llm_output = vllm_model.model.encode( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) - assert llm_output == f"""truncate_prompt_tokens value + assert ( + llm_output + == f"""truncate_prompt_tokens value ({truncate_prompt_tokens}) is greater than max_model_len ({max_model_len}). Please, select a smaller truncation size.""" + ) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 98461676aa47..2b98e1530474 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -3,27 +3,41 @@ """Common tests for testing .generate() functionality for single / multiple image, embedding, and video support for different VLMs in vLLM. """ + import math import os from collections import defaultdict from pathlib import PosixPath import pytest -from transformers import (AutoModel, AutoModelForImageTextToText, - AutoModelForTextToWaveform, AutoModelForVision2Seq) +from transformers import ( + AutoModel, + AutoModelForImageTextToText, + AutoModelForTextToWaveform, + AutoModelForVision2Seq, +) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, - ImageTestAssets, VideoTestAssets, VllmRunner) -from ....utils import (create_new_process_for_each_test, large_gpu_mark, - multi_gpu_marks) +from ....conftest import ( + IMAGE_ASSETS, + AudioTestAssets, + HfRunner, + ImageTestAssets, + VideoTestAssets, + VllmRunner, +) +from ....utils import create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks from ...utils import check_outputs_equal from .vlm_utils import custom_inputs, model_utils, runners from .vlm_utils.case_filtering import get_parametrized_options -from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs, - VLMTestInfo, VLMTestType) +from .vlm_utils.types import ( + CustomTestOptions, + ExpandableVLMTestArgs, + VLMTestInfo, + VLMTestType, +) # This hack is needed for phi3v & paligemma models # ROCm Triton FA can run into shared memory issues with these models, @@ -736,7 +750,7 @@ def _mark_splits( new_test_settings = dict[str, VLMTestInfo]() for i in range(num_groups): - models_in_group = models[i * split_size:(i + 1) * split_size] + models_in_group = models[i * split_size : (i + 1) * split_size] for model in models_in_group: for info in test_infos_by_model[model]: @@ -767,12 +781,17 @@ def _mark_splits( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=False, - )) -def test_single_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): + ), +) +def test_single_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: 
ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -792,12 +811,17 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=False, - )) -def test_multi_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): + ), +) +def test_multi_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -817,12 +841,16 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=False, - )) -def test_image_embedding_models(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): + ), +) +def test_image_embedding_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -841,10 +869,16 @@ def test_image_embedding_models(model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=False, - )) -def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): + ), +) +def test_video_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -863,10 +897,16 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=False, - )) -def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): + ), +) +def test_audio_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -885,7 +925,8 @@ def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=False, - )) + ), +) def test_custom_inputs_models( model_type: str, test_case: ExpandableVLMTestArgs, @@ -911,13 +952,18 @@ def test_custom_inputs_models( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, - test_case: 
ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): +def test_single_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -937,13 +983,18 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): +def test_multi_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -963,14 +1014,17 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_image_embedding_models_heavy(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - monkeypatch): +def test_image_embedding_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -989,11 +1043,16 @@ def test_image_embedding_models_heavy(model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=True, - )) -def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): + ), +) +def test_video_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -1012,11 +1071,16 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=True, - )) -def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): + ), +) +def test_audio_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, + monkeypatch, +): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -1035,7 +1099,8 @@ def test_audio_models_heavy(model_type: str, test_case: 
ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_custom_inputs_models_heavy( model_type: str, diff --git a/tests/models/multimodal/generation/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py index a622957f96f6..92a993d157b0 100644 --- a/tests/models/multimodal/generation/test_florence2.py +++ b/tests/models/multimodal/generation/test_florence2.py @@ -17,12 +17,12 @@ # Florence-2 model repo's tokenizer config is missing some special tokens. # Therefore, we use a converted tokenizer from a forked repo TOKENIZER = "Isotr0py/Florence-2-tokenizer" -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "", # special task token which will output special tokens - "cherry_blossom": - "Describe in detail what is shown in the image.", -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "", # special task token which will output special tokens + "cherry_blossom": "Describe in detail what is shown in the image.", + } +) def get_hf_images_prompts( @@ -35,13 +35,13 @@ def get_hf_images_prompts( ExplicitEncoderDecoderPrompt( encoder_prompt=encoder_prompt["prompt"], decoder_prompt=None, - )) + ) + ) images.append(encoder_prompt["multi_modal_data"]["image"]) return prompts, images -def hf_to_vllm_output(hf_output: tuple[list[int], str, - Optional[SampleLogprobs]]): +def hf_to_vllm_output(hf_output: tuple[list[int], str, Optional[SampleLogprobs]]): """Sanitize hf output to be comparable with vllm output.""" output_ids, output_str, out_logprobs = hf_output @@ -62,35 +62,39 @@ def run_test( tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ) -> None: - with vllm_runner(model, - max_num_seqs=8, - tokenizer_name=TOKENIZER, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + max_num_seqs=8, + tokenizer_name=TOKENIZER, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + ) as vllm_model: vllm_outputs_per_case = [ vllm_model.generate_encoder_decoder_greedy_logprobs( prompts, max_tokens, num_logprobs=num_logprobs, skip_special_tokens=False, - ) for prompts in inputs + ) + for prompts in inputs ] hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.lm_head + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.lm_head + ) hf_outputs_per_case = [ hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, max_tokens, num_logprobs=num_logprobs, images=images) + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in hf_inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs], outputs_1_lst=vllm_outputs, @@ -120,20 +124,31 @@ def run_test( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - 
image_assets: ImageTestAssets, model: str, - size_factors: list[int], dtype: str, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, + model: str, + size_factors: list[int], + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [[ - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=prompt, - multi_modal_data={"image": rescale_image_size(image, factor)}), - decoder_prompt=None, - ) for factor in size_factors - ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + [ + ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=prompt, + multi_modal_data={"image": rescale_image_size(image, factor)}, + ), + decoder_prompt=None, + ) + for factor in size_factors + ] + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index c5ffa5f3a70a..563626961106 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -10,8 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SampleLogprobs -from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, - VllmRunner) +from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -64,50 +63,49 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"audio": 1}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=64, - enforce_eager=True, + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"audio": 1}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=64, + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("audio", 1, audio_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + lora_request=lora_request, + ) for prompts, audios in inputs ] - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: - + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=[audios], - eos_token_id=eos_token_id) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=[audios], + eos_token_id=eos_token_id, + ) for prompts, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, 
vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(output) for output in vllm_outputs - ], + outputs_1_lst=[vllm_to_hf_output(output) for output in vllm_outputs], name_0="hf", name_1="vllm", ) @@ -118,9 +116,16 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, - audio_assets: AudioTestAssets, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + model: str, + audio_assets: AudioTestAssets, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index 949c0a80d31b..2aa2cb757e57 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -28,8 +28,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: give the same result. """ - image_cherry = convert_image_mode( - ImageAsset("cherry_blossom").pil_image, "RGB") + image_cherry = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB") images = [image_cherry, image_stop] video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays @@ -47,29 +46,30 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: ), ] - with vllm_runner(model, - task="generate", - dtype=dtype, - limit_mm_per_prompt={"image": 2}, - max_model_len=32768, - max_num_seqs=2, - tensor_parallel_size=1, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + task="generate", + dtype=dtype, + limit_mm_per_prompt={"image": 2}, + max_model_len=32768, + max_num_seqs=2, + tensor_parallel_size=1, + enforce_eager=True, + ) as vllm_model: vllm_outputs_per_case = [ - vllm_model.generate_greedy(prompts, - max_tokens, - images=images, - videos=videos) + vllm_model.generate_greedy( + prompts, max_tokens, images=images, videos=videos + ) for prompts, images, videos in inputs ] all_results = [output[0][1] for output in vllm_outputs_per_case] - outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n")) - for total_str in all_results] - prompt_lengths = [prompt_len for _, prompt_len in outputs] - generated_strs = [ - total_str[prompt_len:] for total_str, prompt_len in outputs + outputs = [ + (total_str, total_str.find("assistant\n") + len("assistant\n")) + for total_str in all_results ] + prompt_lengths = [prompt_len for _, prompt_len in outputs] + generated_strs = [total_str[prompt_len:] for total_str, prompt_len in outputs] interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths interleaved_output_str, noninterleaved_output_str = generated_strs diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index 2bb01e494d43..deb9fea82bc9 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -9,17 +9,24 @@ from vllm import LLM, SamplingParams from vllm.attention.backends.flash_attn import 
FlashAttentionMetadata -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) +from vllm.attention.selector import ( + _Backend, + _cached_get_attn_backend, + global_force_attn_backend_context_manager, +) from vllm.model_executor.models.mllama import MllamaForConditionalGeneration from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + ImageTestAssets, + PromptImageInput, + VllmRunner, +) from ....quantization.utils import is_quant_method_supported -from ....utils import (create_new_process_for_each_test, large_gpu_test, - multi_gpu_test) +from ....utils import create_new_process_for_each_test, large_gpu_test, multi_gpu_test from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 3 @@ -27,12 +34,12 @@ LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|image|><|begin_of_text|>The meaning of the image is", - "cherry_blossom": - "<|image|><|begin_of_text|>The city is", -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|image|><|begin_of_text|>The meaning of the image is", + "cherry_blossom": "<|image|><|begin_of_text|>The city is", + } +) text_only_prompts = [ "The color of the sky is blue but sometimes it can also be", @@ -43,32 +50,57 @@ ] # Indices for inputs -TEXT_ONLY = '0' -IMAGE_AT_BEG = '1' -IMAGE_AT_MIDDLE = '2' -TWO_IMAGES = '3' +TEXT_ONLY = "0" +IMAGE_AT_BEG = "1" +IMAGE_AT_MIDDLE = "2" +TWO_IMAGES = "3" # Input tokenized prompt_data = { # Tell me a story TEXT_ONLY: [41551, 757, 264, 3446], # <|image|> What's the content of this image - IMAGE_AT_BEG: - [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220], + IMAGE_AT_BEG: [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220], # Hello <|image|>What' the content of this image - IMAGE_AT_MIDDLE: - [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217], - #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501 + IMAGE_AT_MIDDLE: [ + 9906, + 220, + MLLAMA_IMAGE_TOKEN_ID, + 3923, + 6, + 279, + 2262, + 315, + 420, + 2217, + ], + # <|image|>Is there a duck in this image?<|image|>What's the animal in this image? 
# noqa: E501 TWO_IMAGES: [ - MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30, - MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30 - ] + MLLAMA_IMAGE_TOKEN_ID, + 3957, + 1070, + 264, + 37085, + 304, + 420, + 2217, + 30, + MLLAMA_IMAGE_TOKEN_ID, + 3923, + 596, + 279, + 10065, + 304, + 420, + 2217, + 30, + ], } -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str +): """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -79,7 +111,8 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] @@ -99,24 +132,28 @@ def _get_inputs( images = [asset.pil_image for asset in image_assets] if size_factors is not None: - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] elif sizes is not None: - inputs_per_image = [( - [ - prompt if size is not None else text_only_prompts[0] - for size in sizes - ], - [ - image.resize(size) if size is not None else None - for size in sizes - ], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [ + prompt if size is not None else text_only_prompts[0] + for size in sizes + ], + [image.resize(size) if size is not None else None for size in sizes], + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] if len(sizes) == 0: inputs_per_image.append( - (text_only_prompts, [None] * len(text_only_prompts))) + (text_only_prompts, [None] * len(text_only_prompts)) + ) else: raise ValueError("You must provide either `size_factors` or `sizes`") @@ -136,8 +173,7 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, -): - ... +): ... @overload @@ -153,8 +189,7 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, -): - ... +): ... def run_test( @@ -200,7 +235,7 @@ def _run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. 
@@ -212,41 +247,39 @@ def _run_test( # max_model_len should be greater than image_feature_size with vllm_runner( - model, - dtype=dtype, - max_model_len=19212, # 3 max size images - max_num_seqs=3, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: + model, + dtype=dtype, + max_model_len=19212, # 3 max size images + max_num_seqs=3, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT}, + ) as vllm_model: vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs ] - with hf_runner(model, - dtype=dtype, - model_kwargs={"device_map": "auto"}, - auto_cls=AutoModelForImageTextToText) as hf_model: + with hf_runner( + model, + dtype=dtype, + model_kwargs={"device_map": "auto"}, + auto_cls=AutoModelForImageTextToText, + ) as hf_model: hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + hf_model.generate_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, vllm_outputs_per_image): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs + vllm_to_hf_output(vllm_output, model) for vllm_output in vllm_outputs ], name_0="hf", name_1="vllm", @@ -273,26 +306,51 @@ def clear_cache(): # Single-size, batched [(512, 512), (512, 512), (512, 512)], # Multi-size, batched - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028)], + [ + (512, 512), + (1024, 512), + (1536, 512), + (2048, 512), + (512, 1024), + (1024, 1024), + (512, 1536), + (512, 2028), + ], # Multi-size, batched, including text only - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028), None], + [ + (512, 512), + (1024, 512), + (1536, 512), + (2048, 512), + (512, 1024), + (1024, 1024), + (512, 1536), + (512, 2028), + None, + ], # mllama has 8 possible aspect ratios, carefully set the sizes # to cover all of them - ]) + ], +) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, - model, sizes, dtype, max_tokens, - num_logprobs, - attn_backend: _Backend) -> None: +def test_models_single_leading_image( + hf_runner, + vllm_runner, + image_assets, + model, + sizes, + dtype, + max_tokens, + num_logprobs, + attn_backend: _Backend, +) -> None: with global_force_attn_backend_context_manager(attn_backend): if attn_backend == _Backend.FLASH_ATTN: # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' + dtype = "bfloat16" run_test( hf_runner, vllm_runner, @@ -313,36 +371,45 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, 
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, - model, dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - +def test_models_multi_leading_images( + hf_runner, + vllm_runner, + image_assets, + model, + dtype, + max_tokens, + num_logprobs, + attn_backend: _Backend, +) -> None: stop_sign = image_assets[0].pil_image cherry_blossom = image_assets[1].pil_image - inputs = [( - [ - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501 - ], - [ - [stop_sign, cherry_blossom], - # Images with different sizes. + inputs = [ + ( [ - stop_sign.resize((512, 512)), - stop_sign, + "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 + "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 + "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501 ], [ - stop_sign, - stop_sign.resize((512, 1536)), - cherry_blossom.resize((512, 1024)), + [stop_sign, cherry_blossom], + # Images with different sizes. + [ + stop_sign.resize((512, 512)), + stop_sign, + ], + [ + stop_sign, + stop_sign.resize((512, 1536)), + cherry_blossom.resize((512, 1024)), + ], ], - ])] + ) + ] with global_force_attn_backend_context_manager(attn_backend): if attn_backend == _Backend.FLASH_ATTN: # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' + dtype = "bfloat16" _run_test( hf_runner, vllm_runner, @@ -362,27 +429,36 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, - dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - +def test_models_interleaved_images( + hf_runner, + vllm_runner, + image_assets, + model, + dtype, + max_tokens, + num_logprobs, + attn_backend: _Backend, +) -> None: stop_sign = image_assets[0].pil_image cherry_blossom = image_assets[1].pil_image - inputs = [( - [ - "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 - "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501 - "which is a stop sign and which is a cherry blossom?", # noqa: E501 - ], - [ - [stop_sign], - [stop_sign, cherry_blossom], - ])] + inputs = [ + ( + [ + "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 + "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501 + "which is a stop sign and which is a cherry blossom?", # noqa: E501 + ], + [ + [stop_sign], + [stop_sign, cherry_blossom], + ], + ) + ] with global_force_attn_backend_context_manager(attn_backend): if attn_backend == _Backend.FLASH_ATTN: # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' + dtype = "bfloat16" _run_test( hf_runner, vllm_runner, @@ -431,8 +507,10 @@ def test_models_distributed( @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is 
not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) def test_bnb_regression( image_assets: ImageTestAssets, model: str, @@ -443,13 +521,10 @@ def test_bnb_regression( prompts = [ { "prompt": "<|begin_of_text|>The content of the image <|image|> is", - "multi_modal_data": { - "image": stop_sign - }, + "multi_modal_data": {"image": stop_sign}, }, { - "prompt": - "The color of the sky is blue but sometimes it can also be", + "prompt": "The color of the sky is blue but sometimes it can also be", }, ] # Test regression about QKVCrossParallelLinear @@ -519,8 +594,8 @@ def test_explicit_implicit_prompt( ) outputs = llm.generate(prompts, sampling_params) n_prompts = len(prompts) - explicit_outputs = outputs[:n_prompts // 2] - implicit_outputs = outputs[n_prompts // 2:] + explicit_outputs = outputs[: n_prompts // 2] + implicit_outputs = outputs[n_prompts // 2 :] for exp_output, imp_output in zip(explicit_outputs, implicit_outputs): assert exp_output.outputs[0].text == imp_output.outputs[0].text @@ -532,20 +607,28 @@ def test_explicit_implicit_prompt( @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, - num_logprobs, attn_backend: _Backend) -> None: - +def test_regression( + vllm_runner, + image_assets, + model, + dtype, + max_tokens, + num_logprobs, + attn_backend: _Backend, +) -> None: stop_sign = image_assets[0].pil_image - with global_force_attn_backend_context_manager(attn_backend), vllm_runner( + with ( + global_force_attn_backend_context_manager(attn_backend), + vllm_runner( model, dtype=dtype, max_model_len=8192, max_num_seqs=4, tensor_parallel_size=1, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: - + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT}, + ) as vllm_model, + ): # Regression tests for https://github.com/vllm-project/vllm/issues/10648 # Number of groups of image tokens is greater than the number of images @@ -553,10 +636,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images" # noqa: E501 image = stop_sign with pytest.raises(ValueError): - vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs, - images=[image]) + vllm_model.generate_greedy_logprobs( + [prompt], max_tokens, num_logprobs, images=[image] + ) # Batch of a text-only and image request that requires cross-attention prompts = [ @@ -567,10 +649,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, None, [stop_sign], ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs, images=images + ) # Test the reverse order too for good measure prompts = [ @@ -581,10 +662,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, [stop_sign], None, ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs, images=images + ) # Mixed batch with text and images with different numbers of tiles prompts = [ @@ -598,10 +678,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, # smaller image must be 2nd 
for the repro [stop_sign.resize((448, 448))], ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs, images=images + ) class DummyModel: @@ -612,22 +691,25 @@ class DummyModel: @pytest.mark.parametrize( "input_indices_and_output", # inputs, (cross_attention_mask, kv_range_for_decode) - [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)), - ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)), - ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - ((23, 24), [[0, 6], [6, 12]])), - ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])), - ([TWO_IMAGES], ((18, 12), [[6, 12]])), - ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))]) + [ + ([TEXT_ONLY], (None, None)), + ([IMAGE_AT_BEG], (None, None)), + ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)), + ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])), + ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])), + ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], ((23, 24), [[0, 6], [6, 12]])), + ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])), + ([TWO_IMAGES], ((18, 12), [[6, 12]])), + ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]])), + ], +) def test_get_cross_attention_mask(input_indices_and_output) -> None: - input_indices, expected_output = input_indices_and_output sequences = [torch.tensor(prompt_data[i]) for i in input_indices] - num_tiles = [[2, 2] if i != TEXT_ONLY else [] for i in input_indices - if i != TEXT_ONLY] + num_tiles = [ + [2, 2] if i != TEXT_ONLY else [] for i in input_indices if i != TEXT_ONLY + ] input = torch.cat(sequences) seq_lens = [len(s) for s in sequences] @@ -651,16 +733,18 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None: dummy = DummyModel() - cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\ - .get_cross_attention_mask(dummy, - input, - attn_data, - num_tiles=num_tiles, - num_tokens_per_tile=3, - dtype=torch.bfloat16) + cross_attention_mask, kv_range_for_decode = ( + MllamaForConditionalGeneration.get_cross_attention_mask( + dummy, + input, + attn_data, + num_tiles=num_tiles, + num_tokens_per_tile=3, + dtype=torch.bfloat16, + ) + ) - expected_cross_attention_mask, expected_kv_range_for_decode = \ - expected_output + expected_cross_attention_mask, expected_kv_range_for_decode = expected_output assert kv_range_for_decode == expected_kv_range_for_decode if expected_cross_attention_mask is not None: @@ -673,11 +757,19 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None: @pytest.mark.core_model @pytest.mark.parametrize( "input_indices", - [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE], - [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]]) + [ + [TEXT_ONLY], + [IMAGE_AT_BEG], + [TEXT_ONLY, IMAGE_AT_BEG], + [IMAGE_AT_MIDDLE], + [TEXT_ONLY, IMAGE_AT_MIDDLE], + [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], + [IMAGE_AT_MIDDLE, TEXT_ONLY], + [TWO_IMAGES], + [TEXT_ONLY, TWO_IMAGES], + ], +) def test_get_full_text_row_masked_out_mask(input_indices) -> None: - sequences = [torch.tensor(prompt_data[i]) for i in input_indices] seq_lens = [len(s) for s in sequences] @@ -708,10 +800,11 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None: dummy = DummyModel() - full_text_row_masked_out_mask = MllamaForConditionalGeneration\ - 
.get_full_text_row_masked_out_mask(dummy, - attn_data, - torch.get_default_device()) + full_text_row_masked_out_mask = ( + MllamaForConditionalGeneration.get_full_text_row_masked_out_mask( + dummy, attn_data, torch.get_default_device() + ) + ) full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze() full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist() @@ -721,30 +814,33 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None: for i, seq_len in enumerate(seq_lens): must_be_masked = input_indices[i] != TEXT_ONLY for _ in range(seq_len): - assert full_text_row_masked_out_mask[idx] == must_be_masked, \ - f"full_text_row_masked_out_mask[{idx}] must be " \ - f"'{must_be_masked}' " + assert full_text_row_masked_out_mask[idx] == must_be_masked, ( + f"full_text_row_masked_out_mask[{idx}] must be '{must_be_masked}' " + ) idx += 1 @pytest.mark.core_model -@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [ - ([6404], [[4]], [6404]), - ([0, 6404], [[4]], [6404]), - ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]), - ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]), -]) -def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles, - expected) -> None: - +@pytest.mark.parametrize( + "encoder_seq_lens, num_tiles, expected", + [ + ([6404], [[4]], [6404]), + ([0, 6404], [[4]], [6404]), + ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]), + ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]), + ], +) +def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles, expected) -> None: dummy = DummyModel() num_tokens_per_tile = 1601 - actual_encoder_seq_lens = MllamaForConditionalGeneration \ - ._get_and_validate_encoder_lens( + actual_encoder_seq_lens = ( + MllamaForConditionalGeneration._get_and_validate_encoder_lens( dummy, encoder_seq_lens, num_tiles, num_tokens_per_tile, ) - assert actual_encoder_seq_lens == expected, \ + ) + assert actual_encoder_seq_lens == expected, ( f"Expected {expected} but got {actual_encoder_seq_lens}" + ) diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 4e8465778e25..2cf6e347a126 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -17,31 +17,39 @@ from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + PromptAudioInput, + PromptImageInput, + VllmRunner, +) from ....utils import large_gpu_test from ...utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": - "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 + } +) +HF_MULTIIMAGE_IMAGE_PROMPT = ( + "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +) 
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join(model_path, "examples", - "what_is_shown_in_this_image.wav") +speech_question = os.path.join( + model_path, "examples", "what_is_shown_in_this_image.wav" +) models = [model_path] -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str +): """Sanitize vllm output to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -71,8 +79,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str, def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, - Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], model: str, *, max_model_len: int, @@ -98,27 +105,29 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("vision", 1, vision_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request, + ) for prompts, images, audios in inputs ] @@ -127,42 +136,36 @@ def run_test( pytest.skip("HF impl is not compatible with current transformers") hf_model_kwargs = {"_attn_implementation": "sdpa"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: - + with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id - def patch_hf_processor(*args, - text="", - images=None, - audio=None, - sampling_rate=None, - **kwargs): + def patch_hf_processor( + *args, text="", images=None, audio=None, sampling_rate=None, **kwargs + ): audios = None if audio is not None and sampling_rate is not None: audios = [(audio, sampling_rate)] - return hf_processor(*args, - text=text, - images=images, - audios=audios, - **kwargs) + return hf_processor( + *args, text=text, images=images, audios=audios, **kwargs + ) hf_model.processor = patch_hf_processor hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - 
images=images, - audios=audios, - eos_token_id=eos_token_id, - num_logits_to_keep=0) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + eos_token_id=eos_token_id, + num_logits_to_keep=0, + ) for prompts, images, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, @@ -189,16 +192,27 @@ def patch_hf_processor(*args, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, @@ -233,16 +247,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_multi_images_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ ( [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], None, ), ] @@ -266,10 +290,15 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, - max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: - +def test_vision_speech_models( + hf_runner, + vllm_runner, + model, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=None) image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 1def825ab087..42fbfb99264b 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -38,33 +38,33 @@ def _create_msg_format(urls: list[str]) -> list[dict[str, 
Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "text": PROMPT, - }] + [{ - "type": "image_url", - "image_url": { - "url": url - } - } for url in urls], - }] + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": PROMPT, + } + ] + + [{"type": "image_url", "image_url": {"url": url}} for url in urls], + } + ] def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "content": PROMPT, - }, *({ - "type": "image", - "image": download_image(url) - } for url in urls)], - }] + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "content": PROMPT, + }, + *({"type": "image", "image": download_image(url)} for url in urls), + ], + } + ] def _create_engine_inputs(urls: list[str]) -> TokensPrompt: @@ -137,11 +137,17 @@ def _dump_outputs_w_logprobs( outputs: OutputsLogprobs, filename: "StrPath", ) -> None: - json_data = [(tokens, text, [{ - k: asdict(v) - for k, v in token_logprobs.items() - } for token_logprobs in (logprobs or [])]) - for tokens, text, logprobs in outputs] + json_data = [ + ( + tokens, + text, + [ + {k: asdict(v) for k, v in token_logprobs.items()} + for token_logprobs in (logprobs or []) + ], + ) + for tokens, text, logprobs in outputs + ] with open(filename, "w") as f: json.dump(json_data, f) @@ -151,10 +157,17 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: with open(filename, "rb") as f: json_data = json.load(f) - return [(tokens, text, [{ - int(k): Logprob(**v) - for k, v in token_logprobs.items() - } for token_logprobs in logprobs]) for tokens, text, logprobs in json_data] + return [ + ( + tokens, + text, + [ + {int(k): Logprob(**v) for k, v in token_logprobs.items()} + for token_logprobs in logprobs + ], + ) + for tokens, text, logprobs in json_data + ] @large_gpu_test(min_gb=80) @@ -167,21 +180,19 @@ def test_chat( model: str, dtype: str, ) -> None: - EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs( - FIXTURE_LOGPROBS_CHAT[model]) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", - max_model_len=max_model_len, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = [] for msg in MSGS: - output = vllm_model.model.chat(msg, - sampling_params=SAMPLING_PARAMS) + output = vllm_model.model.chat(msg, sampling_params=SAMPLING_PARAMS) outputs.extend(output) @@ -190,46 +201,58 @@ def test_chat( for i in range(len(logprobs)): assert logprobs[i][-1] is None logprobs[i] = logprobs[i][:-1] - check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, - outputs_1_lst=logprobs, - name_0="h100_ref", - name_1="output") + check_logprobs_close( + outputs_0_lst=EXPECTED_CHAT_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output", + ) @large_gpu_test(min_gb=48) -@pytest.mark.parametrize("prompt,expected_ranges", - [(_create_engine_inputs_hf(IMG_URLS[:1]), - [PlaceholderRange(offset=11, length=494)]), - (_create_engine_inputs_hf(IMG_URLS[1:4]), [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ])]) -def test_multi_modal_placeholders(vllm_runner, prompt, - expected_ranges: 
list[PlaceholderRange], - monkeypatch) -> None: - +@pytest.mark.parametrize( + "prompt,expected_ranges", + [ + ( + _create_engine_inputs_hf(IMG_URLS[:1]), + [PlaceholderRange(offset=11, length=494)], + ), + ( + _create_engine_inputs_hf(IMG_URLS[1:4]), + [ + PlaceholderRange(offset=11, length=266), + PlaceholderRange(offset=277, length=1056), + PlaceholderRange(offset=1333, length=418), + ], + ), + ], +) +def test_multi_modal_placeholders( + vllm_runner, prompt, expected_ranges: list[PlaceholderRange], monkeypatch +) -> None: # This placeholder checking test only works with V0 engine # where `multi_modal_placeholders` is returned with `RequestOutput` monkeypatch.setenv("VLLM_USE_V1", "0") with vllm_runner( - "mistral-community/pixtral-12b", - max_model_len=8192, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + "mistral-community/pixtral-12b", + max_model_len=8192, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = vllm_model.model.generate(prompt) assert len(outputs) == 1, f"{len(outputs)=}" output: RequestOutput = outputs[0] - assert hasattr(output, - "multi_modal_placeholders"), f"{output.__dict__=}" - assert "image" in output.multi_modal_placeholders, \ + assert hasattr(output, "multi_modal_placeholders"), f"{output.__dict__=}" + assert "image" in output.multi_modal_placeholders, ( f"{output.multi_modal_placeholders.keys()=}" - image_placeholder_ranges: list[ - PlaceholderRange] = output.multi_modal_placeholders["image"] - assert len(image_placeholder_ranges) == len( - expected_ranges), f"{image_placeholder_ranges=}" - for real_range, expected_range in zip(image_placeholder_ranges, - expected_ranges): - assert real_range == expected_range, \ - f"{real_range=} {expected_range=}" + ) + image_placeholder_ranges: list[PlaceholderRange] = ( + output.multi_modal_placeholders["image"] + ) + assert len(image_placeholder_ranges) == len(expected_ranges), ( + f"{image_placeholder_ranges=}" + ) + for real_range, expected_range in zip( + image_placeholder_ranges, expected_ranges + ): + assert real_range == expected_range, f"{real_range=} {expected_range=}" diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index a2793b8c8ddf..46acaabed5b2 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -11,8 +11,13 @@ from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video -from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, - PromptVideoInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + VIDEO_ASSETS, + PromptImageInput, + PromptVideoInput, + VllmRunner, +) from ...utils import check_logprobs_close @@ -21,7 +26,7 @@ def use_v0_only(monkeypatch): """ V1 Test: batch_make_xxxxx_embeddings calls a V0 internal """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") models = ["Qwen/Qwen2-VL-2B-Instruct"] @@ -36,28 +41,29 @@ def qwen2_vl_chat_template(*query): return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 -IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the biggest text's content in this image?", - ), - "cherry_blossom": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the season shown in this image? 
", - "Reply with a short sentence (no more than 20 words)", - ), -}) - -VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "baby_reading": - qwen2_vl_chat_template( - VIDEO_PLACEHOLDER, - "Describe this video with a short sentence ", - "(no more than 20 words)", - ), -}) +IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the biggest text's content in this image?", + ), + "cherry_blossom": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the season shown in this image? ", + "Reply with a short sentence (no more than 20 words)", + ), + } +) + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen2_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) MULTIIMAGE_PROMPT = qwen2_vl_chat_template( IMAGE_PLACEHOLDER, @@ -79,17 +85,19 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict): def batch_make_image_embeddings( - image_batches: list[Union[Image.Image, list[Image.Image]]], processor, - llm: VllmRunner) -> list[Qwen2VLPromptImageEmbeddingInput]: + image_batches: list[Union[Image.Image, list[Image.Image]]], + processor, + llm: VllmRunner, +) -> list[Qwen2VLPromptImageEmbeddingInput]: """batched image embeddings for Qwen2-VL - This will infer all images' embeddings in a single batch, + This will infer all images' embeddings in a single batch, and split the result according to input batches. image_batches: - Single-image batches: `list[Image.Image]` - Multiple-image batches: `list[list[Image.Image]]]` - + returns: `list[Qwen2VLPromptImageEmbeddingInput]` """ @@ -110,9 +118,9 @@ def batch_make_image_embeddings( # image to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=images, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=images, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] @@ -121,12 +129,11 @@ def get_image_embeds(model): with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - image_grid_thw_on_device = image_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=image_grid_thw_on_device) + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + image_grid_thw_on_device = image_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual(pixel_values_on_device, grid_thw=image_grid_thw_on_device) # V1 Test: this calls a V0 internal. 
image_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -140,16 +147,21 @@ def get_image_embeds(model): merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count]) + for grid_thw in image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ] + ) - result.append({ - "image_embeds": - image_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "image_grid_thw": - image_grid_thw[image_counter:image_counter + - cur_batch_image_count], - }) + result.append( + { + "image_embeds": image_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "image_grid_thw": image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ], + } + ) embed_counter += cur_batch_embed_len image_counter += cur_batch_image_count @@ -163,13 +175,13 @@ def get_image_embeds(model): def batch_make_video_embeddings( - video_batches: PromptVideoInput, processor, - llm: VllmRunner) -> list[Qwen2VLPromptVideoEmbeddingInput]: + video_batches: PromptVideoInput, processor, llm: VllmRunner +) -> list[Qwen2VLPromptVideoEmbeddingInput]: """batched video embeddings for Qwen2-VL A NDArray represents a single video's all frames. - This will infer all videos' embeddings in a single batch, + This will infer all videos' embeddings in a single batch, and split the result according to input batches. video_batches: @@ -194,9 +206,9 @@ def batch_make_video_embeddings( # video to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=None, videos=videos, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=None, videos=videos, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] @@ -205,12 +217,11 @@ def get_image_embeds(model): with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - video_grid_thw_on_device = video_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=video_grid_thw_on_device) + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + video_grid_thw_on_device = video_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual(pixel_values_on_device, grid_thw=video_grid_thw_on_device) # V1 Test: this calls a V0 internal. 
video_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -224,16 +235,21 @@ def get_image_embeds(model): merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count]) + for grid_thw in video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ] + ) - result.append({ - "video_embeds": - video_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "video_grid_thw": - video_grid_thw[video_counter:video_counter + - cur_batch_video_count], - }) + result.append( + { + "video_embeds": video_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "video_grid_thw": video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ], + } + ) embed_counter += cur_batch_embed_len video_counter += cur_batch_video_count @@ -266,25 +282,24 @@ def run_embedding_input_test( processor = AutoProcessor.from_pretrained(model) # max_model_len should be greater than image_feature_size - with vllm_runner(model, - task="generate", - max_model_len=4000, - max_num_seqs=3, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - + with vllm_runner( + model, + task="generate", + max_model_len=4000, + max_num_seqs=3, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit, "video": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + ) as vllm_model: outputs_per_case_for_original_input = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images or None, - videos=videos or None) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None, + ) for prompts, images, videos in inputs ] @@ -293,17 +308,19 @@ def run_embedding_input_test( prompts, max_tokens, num_logprobs=num_logprobs, - images=batch_make_image_embeddings( - images, processor, vllm_model) if images else None, - videos=batch_make_video_embeddings( - videos, processor, vllm_model) if videos else None) + images=batch_make_image_embeddings(images, processor, vllm_model) + if images + else None, + videos=batch_make_video_embeddings(videos, processor, vllm_model) + if videos + else None, + ) for prompts, images, videos in inputs ] - for outputs_for_original_input, \ - outputs_for_embeddings_input \ - in zip(outputs_per_case_for_original_input, - outputs_per_case_for_embeddings_input): + for outputs_for_original_input, outputs_for_embeddings_input in zip( + outputs_per_case_for_original_input, outputs_per_case_for_embeddings_input + ): check_logprobs_close( outputs_0_lst=outputs_for_original_input, outputs_1_lst=outputs_for_embeddings_input, @@ -328,18 +345,25 @@ def run_embedding_input_test( @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, - size_factors, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_image_embeddings_input( + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: list[tuple[ - 
list[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( [prompt for _ in size_factors], [rescale_image_size(image, factor) for factor in size_factors], [], - ) for image, prompt in zip(images, IMAGE_PROMPTS)] + ) + for image, prompt in zip(images, IMAGE_PROMPTS) + ] run_embedding_input_test( vllm_runner, @@ -370,21 +394,27 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, - model, size_factors, - dtype: str, max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_multiple_image_embeddings_input( + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: list[tuple[list[str], PromptImageInput, - PromptVideoInput]] = [( - [MULTIIMAGE_PROMPT for _ in size_factors], - [[ - rescale_image_size(image, factor) - for image in images - ] for factor in size_factors], - [], - )] + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( + [MULTIIMAGE_PROMPT for _ in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], + [], + ) + ] run_embedding_input_test( vllm_runner, @@ -414,22 +444,29 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, - size_factors, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_video_embeddings_input( + vllm_runner, + video_assets, + model, + size_factors, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: num_frames = 4 sampled_vids = [ sample_frames_from_video(asset.np_ndarrays, num_frames) for asset in video_assets ] - inputs_per_case: list[tuple[ - list[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( [prompt for _ in size_factors], [], [rescale_video_size(video, factor) for factor in size_factors], - ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] + ) + for video, prompt in zip(sampled_vids, VIDEO_PROMPTS) + ] run_embedding_input_test( vllm_runner, diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index e7e7bd3154a1..da1e7c7486fd 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -15,12 +15,12 @@ MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" -AUDIO_PROMPTS = AUDIO_ASSETS.prompts({ - "mary_had_lamb": - "Transcribe this into English.", - "winning_call": - "What is happening in this audio clip?", -}) +AUDIO_PROMPTS = AUDIO_ASSETS.prompts( + { + "mary_had_lamb": "Transcribe this into English.", + "winning_call": "What is happening in this audio clip?", + } +) MULTI_AUDIO_PROMPT = "Describe each of the audios above." @@ -33,7 +33,7 @@ "enable_chunked_prefill": True, "max_num_seqs": 2, # Use a very small limit to exercise chunked prefill. 
- "max_num_batched_tokens": 16 + "max_num_batched_tokens": 16, } @@ -43,27 +43,33 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: for key, value in params_kwargs.items(): if isinstance(value, bool): if value: - args.append(f"--{key.replace('_','-')}") + args.append(f"--{key.replace('_', '-')}") else: - args.append(f"--{key.replace('_','-')}={value}") + args.append(f"--{key.replace('_', '-')}={value}") return args -@pytest.fixture(params=[ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) +@pytest.fixture( + params=[ + pytest.param({}, marks=pytest.mark.cpu_model), + pytest.param(CHUNKED_PREFILL_KWARGS), + ] +) def server(request, audio_assets: AudioTestAssets): args = [ - "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", "--limit-mm-per-prompt", - json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" + json.dumps({"audio": len(audio_assets)}), + "--trust-remote-code", ] + params_kwargs_to_cli_args(request.param) - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"} + ) as remote_server: yield remote_server @@ -77,12 +83,11 @@ def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) placeholder = f"{placeholder}\n" * audio_count - return tokenizer.apply_chat_template([{ - 'role': 'user', - 'content': f"{placeholder}{question}" - }], - tokenize=False, - add_generation_prompt=True) + return tokenizer.apply_chat_template( + [{"role": "user", "content": f"{placeholder}{question}"}], + tokenize=False, + add_generation_prompt=True, + ) def run_multi_audio_test( @@ -99,19 +104,21 @@ def run_multi_audio_test( model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - with vllm_runner(model, - dtype=dtype, - enforce_eager=True, - limit_mm_per_prompt={ - "audio": - max((len(audio) for _, audio in prompts_and_audios)) - }, - **kwargs) as vllm_model: + with vllm_runner( + model, + dtype=dtype, + enforce_eager=True, + limit_mm_per_prompt={ + "audio": max((len(audio) for _, audio in prompts_and_audios)) + }, + **kwargs, + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, num_logprobs=num_logprobs, - audios=[audios for _, audios in prompts_and_audios]) + audios=[audios for _, audios in prompts_and_audios], + ) # The HuggingFace model doesn't support multiple audios yet, so # just assert that some tokens were generated. 
@@ -122,21 +129,25 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("vllm_kwargs", [ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, num_logprobs: int, - vllm_kwargs: dict) -> None: - - vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, - VLLM_PLACEHOLDER) +@pytest.mark.parametrize( + "vllm_kwargs", + [ + pytest.param({}, marks=pytest.mark.cpu_model), + pytest.param(CHUNKED_PREFILL_KWARGS), + ], +) +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, + vllm_kwargs: dict, +) -> None: + vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -149,28 +160,25 @@ def test_models_with_multiple_audios(vllm_runner, async def test_online_serving(client, audio_assets: AudioTestAssets): """Exercises online serving with/without chunked prefill enabled.""" - messages = [{ - "role": - "user", - "content": [ - *[{ - "type": "audio_url", - "audio_url": { - "url": audio.url - } - } for audio in audio_assets], - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" - }, - ], - }] - - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + messages = [ + { + "role": "user", + "content": [ + *[ + {"type": "audio_url", "audio_url": {"url": audio.url}} + for audio in audio_assets + ], + { + "type": "text", + "text": f"What's happening in these {len(audio_assets)} audio clips?", + }, + ], + } + ] + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index b4439dfe020c..aa9628435e4d 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -6,8 +6,12 @@ import pytest import pytest_asyncio from mistral_common.audio import Audio -from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, - TextChunk, UserMessage) +from mistral_common.protocol.instruct.messages import ( + AudioChunk, + RawAudio, + TextChunk, + UserMessage, +) from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -17,8 +21,12 @@ MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507" MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] @@ -30,10 +38,9 @@ def server(request, audio_assets: AudioTestAssets): json.dumps({"audio": len(audio_assets)}), ] + MISTRAL_FORMAT_ARGS - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"} + ) as remote_server: 
yield remote_server @@ -64,15 +71,17 @@ def _get_prompt(audio_assets, question): @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -92,23 +101,22 @@ def asset_to_chunk(asset): return audio_dict audio_chunks = [asset_to_chunk(asset) for asset in audio_assets] - messages = [{ - "role": - "user", - "content": [ - *audio_chunks, - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" - }, - ], - }] - - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + messages = [ + { + "role": "user", + "content": [ + *audio_chunks, + { + "type": "text", + "text": f"What's happening in these {len(audio_assets)} audio clips?", + }, + ], + } + ] + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 363d55153aac..7eac8bb1b47a 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -12,8 +12,7 @@ PROMPTS = [ { - "prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", "multi_modal_data": { "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, }, @@ -25,9 +24,8 @@ "audio": AudioAsset("winning_call").audio_and_sample_rate, }, }, - "decoder_prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - } + "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + }, ] EXPECTED = { @@ -41,7 +39,7 @@ " is June and the third base. They're going to wave him in. The throw" " to the plate will be late. The Mariners are going to play for the" " American League Championship. I don't believe it. It just continues" - " by all five." + " by all five.", ], "openai/whisper-small": [ " The first words I spoke in the original pornograph. A little piece" @@ -51,7 +49,7 @@ " comes joy. Here is Junior to third base. They're gonna wave him" " in. The throw to the plate will be late. The Mariners are going to" " play for the American League Championship. I don't believe it. It" - " just continues. My, oh my." + " just continues. My, oh my.", ], "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" @@ -62,7 +60,7 @@ " Jorgen at third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh" - " my." 
+ " my.", ], "openai/whisper-large-v3": [ " The first words I spoke in the original phonograph, a little piece" @@ -73,7 +71,7 @@ " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." + " my.", ], "openai/whisper-large-v3-turbo": [ " The first words I spoke in the original phonograph, a little piece" @@ -84,8 +82,8 @@ " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." - ] + " my.", + ], } @@ -100,11 +98,11 @@ def run_test( expected_list = EXPECTED[model] * 10 with vllm_runner( - model, - dtype="half", - max_model_len=448, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, + model, + dtype="half", + max_model_len=448, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: llm = vllm_model.model @@ -123,7 +121,8 @@ def run_test( @pytest.mark.core_model @pytest.mark.parametrize( - "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) + "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"] +) @create_new_process_for_each_test() def test_models(vllm_runner, model) -> None: run_test( diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 03c08240d6a8..859c2ffd9df1 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Helpers for building inputs that can be leveraged for different test types. -""" +"""Helpers for building inputs that can be leveraged for different test types.""" + from collections.abc import Iterable from pathlib import PosixPath from typing import Callable, Optional, Union @@ -10,20 +10,30 @@ from vllm.multimodal.audio import AudioResampler from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets -from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS, - TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER, - TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, - ImageSizeWrapper, PromptWithMultiModalInput, SizeType, - VLMTestInfo) - - -def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], - str], - test_placeholder: str) -> str: +from .types import ( + SINGLE_AUDIO_BASE_PROMPT, + SINGLE_IMAGE_BASE_PROMPTS, + TEST_AUDIO_PLACEHOLDER, + TEST_IMG_PLACEHOLDER, + TEST_VIDEO_PLACEHOLDER, + VIDEO_BASE_PROMPT, + ImageSizeWrapper, + PromptWithMultiModalInput, + SizeType, + VLMTestInfo, +) + + +def replace_test_placeholder( + prompt: str, mm_idx_to_prompt: Callable[[int], str], test_placeholder: str +) -> str: """Given a prompt, replaces each test placeholder with the model-specific tag. 
""" @@ -35,11 +45,13 @@ def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], return img_prompt -def get_model_prompts(base_prompts: Iterable[str], - img_idx_to_prompt: Optional[Callable[[int], str]], - video_idx_to_prompt: Optional[Callable[[int], str]], - audio_idx_to_prompt: Optional[Callable[[int], str]], - prompt_formatter: Callable[[str], str]) -> list[str]: +def get_model_prompts( + base_prompts: Iterable[str], + img_idx_to_prompt: Optional[Callable[[int], str]], + video_idx_to_prompt: Optional[Callable[[int], str]], + audio_idx_to_prompt: Optional[Callable[[int], str]], + prompt_formatter: Callable[[str], str], +) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting to get the test prompt string for this model. @@ -56,19 +68,19 @@ def get_model_prompts(base_prompts: Iterable[str], # Replace the multimodal placeholders in the base prompt with # the correct ones for the model that we are testing if img_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - img_idx_to_prompt, - TEST_IMG_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, img_idx_to_prompt, TEST_IMG_PLACEHOLDER + ) if video_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - video_idx_to_prompt, - TEST_VIDEO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, video_idx_to_prompt, TEST_VIDEO_PLACEHOLDER + ) if audio_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - audio_idx_to_prompt, - TEST_AUDIO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, audio_idx_to_prompt, TEST_AUDIO_PLACEHOLDER + ) # Apply the prompt formatter to wrap the base prompt with # the correct media placeholders to get the model test prompt @@ -84,14 +96,15 @@ def build_single_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build single image inputs") + raise ValueError("Prompt formatter must be set to build single image inputs") - model_prompts = get_model_prompts(test_info.single_image_prompts, - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + test_info.single_image_prompts, + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) # For models that require a local path / URL encoded in the image; export # assets and encode into tmp_path for this test. 
This should be avoided @@ -110,8 +123,8 @@ def build_single_image_inputs_from_test_info( def build_single_image_inputs( - images, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + images, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: # For every image / prompt pair, get a pair containing two lists of # length size_factors, where the first contains duplicates of the model # prompt [str], and the second contains copies of the image after being @@ -125,7 +138,8 @@ def build_single_image_inputs( apply_image_size_scaling(image, size, size_wrapper.type) for size in size_wrapper.data ], - ) for image, prompt in zip(images, model_prompts) + ) + for image, prompt in zip(images, model_prompts) ] @@ -136,14 +150,15 @@ def build_multi_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build multi image inputs") + raise ValueError("Prompt formatter must be set to build multi image inputs") - model_prompts = get_model_prompts([test_info.multi_image_prompt], - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + [test_info.multi_image_prompt], + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) if test_info.prompt_path_encoder is not None: if tmp_path is None: @@ -164,16 +179,20 @@ def build_multi_image_inputs_from_test_info( def build_multi_image_inputs( - image_lists, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + image_lists, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - image_data=[[ - apply_image_size_scaling(image, size, size_wrapper.type) - for image in images - ] for size in size_wrapper.data], - ) for images, prompt in zip(image_lists, model_prompts) + image_data=[ + [ + apply_image_size_scaling(image, size, size_wrapper.type) + for image in images + ] + for size in size_wrapper.data + ], + ) + for images, prompt in zip(image_lists, model_prompts) ] @@ -185,10 +204,10 @@ def build_embedding_inputs_from_test_info( # These conditions will always be true if invoked through filtering, # but we still check them in case this is ever called directly if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build image embedding inputs") - if size_wrapper.type != SizeType.SIZE_FACTOR or not \ - all(factor == 1.0 for factor in size_wrapper.data): + raise ValueError("Prompt formatter must be set to build image embedding inputs") + if size_wrapper.type != SizeType.SIZE_FACTOR or not all( + factor == 1.0 for factor in size_wrapper.data + ): raise ValueError("Embedding tests require constant (1.0) size factors") if test_info.convert_assets_to_embeddings is None: raise ValueError("No conversion func for getting embeddings found") @@ -209,8 +228,7 @@ def build_embedding_inputs_from_test_info( assert len(images) == len(model_prompts) inputs = build_single_image_inputs(images, model_prompts, size_wrapper) - vllm_embeddings = build_single_image_inputs(embeds, model_prompts, - size_wrapper) + vllm_embeddings = build_single_image_inputs(embeds, model_prompts, size_wrapper) return inputs, vllm_embeddings @@ -235,21 
+253,22 @@ def build_video_inputs_from_test_info( for asset in video_assets ] - video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE - else rescale_video_size) + video_scaler = ( + resize_video if size_wrapper.type == SizeType.FIXED_SIZE else rescale_video_size + ) return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - video_data=[ - video_scaler(video, size) for size in size_wrapper.data - ], - ) for video, prompt in zip(sampled_vids, model_prompts) + video_data=[video_scaler(video, size) for size in size_wrapper.data], + ) + for video, prompt in zip(sampled_vids, model_prompts) ] -def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], - size_type: SizeType): +def apply_image_size_scaling( + image, size: Union[float, tuple[int, int]], size_type: SizeType +): """Applies a size scaler to one image; this can be a an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we @@ -285,13 +304,16 @@ def build_audio_inputs_from_test_info( method="librosa", ) audios = [asset.audio_and_sample_rate for asset in audio_assets] - resampled_audios = [( - resampler.resample( - audio, - orig_sr=sr, - ), - int(resampler.target_sr), - ) for audio, sr in audios] + resampled_audios = [ + ( + resampler.resample( + audio, + orig_sr=sr, + ), + int(resampler.target_sr), + ) + for audio, sr in audios + ] return [ PromptWithMultiModalInput( diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 336e2dd2b120..fe36dfbf26f6 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -4,19 +4,28 @@ modality, getting all combinations (similar to pytest's parametrization), handling multimodal placeholder substitution, and so on. """ + import itertools from collections import OrderedDict from collections.abc import Iterable import pytest -from .types import (EMBEDDING_SIZE_FACTORS, ExpandableVLMTestArgs, - ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType) +from .types import ( + EMBEDDING_SIZE_FACTORS, + ExpandableVLMTestArgs, + ImageSizeWrapper, + SizeType, + VLMTestInfo, + VLMTestType, +) def get_filtered_test_settings( - test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, - new_proc_per_test: bool) -> dict[str, VLMTestInfo]: + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + new_proc_per_test: bool, +) -> dict[str, VLMTestInfo]: """Given the dict of potential test settings to run, return a subdict of tests who have the current test type enabled with the matching val for fork_per_test. 
@@ -25,7 +34,8 @@ def get_filtered_test_settings( def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): return test_info.test_type == test_type or ( isinstance(test_info.test_type, Iterable) - and test_type in test_info.test_type) + and test_type in test_info.test_type + ) matching_tests = {} for test_name, test_info in test_settings.items(): @@ -36,62 +46,69 @@ def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): assert test_info.convert_assets_to_embeddings is not None # Custom test inputs need to explicitly define the mm limit/inputs if matches_test_type(test_info, VLMTestType.CUSTOM_INPUTS): - assert (test_info.custom_test_opts is not None - and isinstance(test_info.custom_test_opts, Iterable)) + assert test_info.custom_test_opts is not None and isinstance( + test_info.custom_test_opts, Iterable + ) # For all types besides custom inputs, we need a prompt formatter else: assert test_info.prompt_formatter is not None # Everything looks okay; keep if this is has correct proc handling - if (test_info.distributed_executor_backend - is not None) == new_proc_per_test: + if ( + test_info.distributed_executor_backend is not None + ) == new_proc_per_test: matching_tests[test_name] = test_info return matching_tests -def get_parametrized_options(test_settings: dict[str, VLMTestInfo], - test_type: VLMTestType, - create_new_process_for_each_test: bool): +def get_parametrized_options( + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + create_new_process_for_each_test: bool, +): """Converts all of our VLMTestInfo into an expanded list of parameters. This is similar to nesting pytest parametrize calls, but done directly through an itertools product so that each test can set things like size factors etc, while still running in isolated test cases. """ matching_tests = get_filtered_test_settings( - test_settings, test_type, create_new_process_for_each_test) + test_settings, test_type, create_new_process_for_each_test + ) # Ensure that something is wrapped as an iterable it's not already - ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e, ) + ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) def get_model_type_cases(model_type: str, test_info: VLMTestInfo): # This is essentially the same as nesting a bunch of mark.parametrize # decorators, but we do it programmatically to allow overrides for on # a per-model basis, while still being able to execute each of these # as individual test cases in pytest. 
- iter_kwargs = OrderedDict([ - ("model", ensure_wrapped(test_info.models)), - ("max_tokens", ensure_wrapped(test_info.max_tokens)), - ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), - ("dtype", ensure_wrapped(test_info.dtype)), - ("distributed_executor_backend", - ensure_wrapped(test_info.distributed_executor_backend)), - ]) + iter_kwargs = OrderedDict( + [ + ("model", ensure_wrapped(test_info.models)), + ("max_tokens", ensure_wrapped(test_info.max_tokens)), + ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), + ("dtype", ensure_wrapped(test_info.dtype)), + ( + "distributed_executor_backend", + ensure_wrapped(test_info.distributed_executor_backend), + ), + ] + ) # num_frames is video only if test_type == VLMTestType.VIDEO: - iter_kwargs["num_video_frames"] = ensure_wrapped( - test_info.num_video_frames) + iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) # No sizes passed for custom inputs, since inputs are directly provided if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) if wrapped_sizes is None: - raise ValueError( - f"Sizes must be set for test type {test_type}") + raise ValueError(f"Sizes must be set for test type {test_type}") iter_kwargs["size_wrapper"] = wrapped_sizes - #Otherwise expand the custom test options instead + # Otherwise expand the custom test options instead elif test_type == VLMTestType.CUSTOM_INPUTS: if test_info.custom_test_opts is None: raise ValueError("Test has type CUSTOM_INPUTS, but none given") @@ -121,8 +138,8 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): def get_wrapped_test_sizes( - test_info: VLMTestInfo, - test_type: VLMTestType) -> tuple[ImageSizeWrapper, ...]: + test_info: VLMTestInfo, test_type: VLMTestType +) -> tuple[ImageSizeWrapper, ...]: """Given a test info which may have size factors or fixed sizes, wrap them and combine them into an iterable, each of which will be used in parameter expansion. 
@@ -133,18 +150,18 @@ def get_wrapped_test_sizes( """ # If it is an embedding test, we always use the EMBEDDING_SIZE_FACTORS if test_type == VLMTestType.EMBEDDING: - return tuple([ - ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) - for factor in EMBEDDING_SIZE_FACTORS - ]) + return tuple( + [ + ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) + for factor in EMBEDDING_SIZE_FACTORS + ] + ) # Audio and Custom inputs have preprocessed inputs elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS): return tuple() - size_factors = test_info.image_size_factors \ - if test_info.image_size_factors else [] - fixed_sizes = test_info.image_sizes \ - if test_info.image_sizes else [] + size_factors = test_info.image_size_factors if test_info.image_size_factors else [] + fixed_sizes = test_info.image_sizes if test_info.image_sizes else [] wrapped_factors = [ ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) @@ -152,8 +169,7 @@ def get_wrapped_test_sizes( ] wrapped_sizes = [ - ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) - for size in fixed_sizes + ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) for size in fixed_sizes ] return tuple(wrapped_factors + wrapped_sizes) diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index 8c83d8f8a8a2..cefb45227fed 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Core test implementation to be shared across modalities.""" + from typing import Any, Callable, Optional import torch @@ -75,16 +76,18 @@ def run_test( if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) - with vllm_runner(model, - max_model_len=max_model_len, - max_num_seqs=max_num_seqs, - dtype=dtype, - limit_mm_per_prompt=limit_mm_per_prompt, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=enforce_eager, - task=task, - **vllm_runner_kwargs_) as vllm_model: + with vllm_runner( + model, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + dtype=dtype, + limit_mm_per_prompt=limit_mm_per_prompt, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=enforce_eager, + task=task, + **vllm_runner_kwargs_, + ) as vllm_model: tokenizer = vllm_model.model.get_tokenizer() vllm_kwargs: dict[str, Any] = {} @@ -94,21 +97,19 @@ def run_test( vllm_kwargs["stop"] = stop_str for prompts, image_data, video_data, audio_data in vllm_inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) vllm_kwargs_with_mm_data = vllm_kwargs | mm_data vllm_output = vllm_model.generate_greedy_logprobs( prompts, max_tokens, num_logprobs=num_logprobs, - **vllm_kwargs_with_mm_data) + **vllm_kwargs_with_mm_data, + ) vllm_outputs_per_mm.append(vllm_output) - hf_model = hf_runner(model, - dtype=dtype, - auto_cls=auto_cls, - model_kwargs=hf_model_kwargs) + hf_model = hf_runner( + model, dtype=dtype, auto_cls=auto_cls, model_kwargs=hf_model_kwargs + ) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: @@ -128,16 +129,15 @@ def run_test( hf_kwargs["stop_strings"] = stop_str for prompts, image_data, video_data, 
audio_data in inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) hf_kwargs_with_mm_data = hf_kwargs | mm_data hf_output = hf_model.generate_greedy_logprobs_limit( prompts, max_tokens, num_logprobs=num_logprobs, tokenizer=tokenizer, - **hf_kwargs_with_mm_data) + **hf_kwargs_with_mm_data, + ) hf_outputs_per_mm.append(hf_output) # Apply output processing / sanitation to the vLLM and HF runner results @@ -149,8 +149,7 @@ def run_test( second_runner_processor=vllm_output_post_proc, ) - for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, - vllm_outputs_per_mm): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, vllm_outputs_per_mm): # This is usually check_logprobs_close, but it's passed through to # allow things like check_outputs_equal where needed comparator( @@ -170,15 +169,19 @@ def process_runner_outputs( ): """Applies the runner processor(s) to the runner outputs, if any.""" if first_runner_processor is not None: - first_runner_outputs = process_outputs(first_runner_processor, model, - first_runner_outputs) + first_runner_outputs = process_outputs( + first_runner_processor, model, first_runner_outputs + ) if second_runner_processor is not None: - second_runner_outputs = process_outputs(second_runner_processor, model, - second_runner_outputs) + second_runner_outputs = process_outputs( + second_runner_processor, model, second_runner_outputs + ) return first_runner_outputs, second_runner_outputs def process_outputs(output_processor, model, outputs_per_image): """Applies a model specific post-processor function to a runner's output""" - return [[output_processor(res, model) for res in outputs] - for outputs in outputs_per_image] + return [ + [output_processor(res, model) for res in outputs] + for outputs in outputs_per_image + ] diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index c53243b42e38..3886547b8a8b 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" + from io import BytesIO from typing import Callable @@ -8,8 +9,11 @@ from PIL import Image from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .builders import build_multi_image_inputs, build_single_image_inputs @@ -18,7 +22,7 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): """Builds inputs for multi-image (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. 
""" @@ -44,7 +48,7 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): stop_sign, rescale_image_size(stop_sign, 0.25), cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) + cherry_blossom.resize((488, 183)), ], cherry_blossom, ] @@ -57,10 +61,11 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): ] -def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], - num_frames: int = 16): +def multi_video_multi_aspect_ratio_inputs( + formatter: Callable[[str], str], num_frames: int = 16 +): """Builds inputs for multi-video (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. """ @@ -84,7 +89,7 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], video, rescale_video_size(video, 0.25), resize_video(video, (183, 488)), - resize_video(video, (488, 183)) + resize_video(video, (488, 183)), ], video, ] @@ -99,7 +104,9 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], def different_patch_input_cases_internvl(): images = [asset.pil_image.resize((896, 896)) for asset in IMAGE_ASSETS] - formatter = lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 + formatter = ( + lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" + ) # noqa: E501 single_img_prompts = [ "\nWhat's the content in the center of the image?", "\nWhat is the season?", @@ -124,8 +131,9 @@ def windows_attention_image_qwen2_5_vl(): question = "Describe the image." img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" - prompt = (f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n<|im_start|>assistant\n" + ) wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) @@ -139,8 +147,9 @@ def video_with_metadata_glm4_1v(): formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] - video_input = [[(rescale_video_size(video_array, scale), metadata)] - for scale in scales] + video_input = [ + [(rescale_video_size(video_array, scale), metadata)] for scale in scales + ] prompts = [formatted_prompt] * len(video_input) return [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index c1a2aa0dcafb..1e005b49d72a 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -4,6 +4,7 @@ for manipulating the input / output of HF & vLLM test runners, which are typically specific to a small subset of models. 
""" + import types from pathlib import PosixPath from typing import Optional, Union @@ -14,8 +15,13 @@ import regex as re import torch from PIL.Image import Image -from transformers import (AutoConfig, AutoTokenizer, BatchFeature, - GenerationConfig, GenerationMixin) +from transformers import ( + AutoConfig, + AutoTokenizer, + BatchFeature, + GenerationConfig, + GenerationMixin, +) from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs @@ -27,8 +33,7 @@ ####### vLLM output processors functions -def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [blip2 models] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -42,8 +47,7 @@ def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [fuyu models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -53,8 +57,8 @@ def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, def qwen_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -64,8 +68,8 @@ def qwen_vllm_to_hf_output( def qwen2_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -75,8 +79,8 @@ def qwen2_vllm_to_hf_output( def kimiv_vl_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -85,23 +89,25 @@ def kimiv_vl_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs -def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_image_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: config = AutoConfig.from_pretrained(model) mm_token_id = config.image_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) def llava_video_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: config = AutoConfig.from_pretrained(model) mm_token_id = config.video_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) -def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, - mm_token_id: int) -> RunnerOutput: +def _llava_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str, mm_token_id: int +) -> RunnerOutput: """Sanitize vllm output [Llava models] to be comparable with hf output.""" output_ids, 
output_str, out_logprobs = vllm_output @@ -109,7 +115,8 @@ def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != mm_token_id or output_ids[idx - 1] != mm_token_id ] @@ -128,8 +135,9 @@ def llava_onevision_hf_model_kwargs(model: str) -> dict: return config.to_dict() -def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_onevision_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: """Sanitize vllm output [llava-onevision] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -140,7 +148,8 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != video_token_id or output_ids[idx - 1] != video_token_id ] @@ -151,8 +160,7 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [mantis] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -161,8 +169,7 @@ def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, return output_ids, hf_output_str, out_logprobs -def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -180,8 +187,7 @@ def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -192,7 +198,8 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] @@ -205,46 +212,40 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|end▁of▁sentence|>"): output_str = output_str.split("<|end▁of▁sentence|>")[0] return output_ids, output_str, out_logprobs -def idefics3_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def idefics3_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith(""): output_str = output_str.split("")[0] return output_ids, output_str, out_logprobs -def smolvlm_trunc_hf_output(hf_output: RunnerOutput, - model: str) 
-> RunnerOutput: +def smolvlm_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: # Based on Idefics3 return idefics3_trunc_hf_output(hf_output, model) -def minicpmv_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): output_str = output_str.split("<|eot_id|>")[0] return output_ids, output_str, out_logprobs -def minimax_vl_01_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minimax_vl_01_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith(""): output_str = output_str.split("")[0] return output_ids, output_str, out_logprobs -def ultravox_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def ultravox_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output tokenizer = AutoTokenizer.from_pretrained(model) @@ -262,8 +263,8 @@ def get_llava_embeddings(image_assets: ImageTestAssets): ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, - assets: Union[list[ImageAsset], ImageTestAssets]) -> str: + tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], ImageTestAssets] +) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image in its @@ -313,8 +314,9 @@ def processor(*args, text="", images=None, **kwargs): return BatchFeature(data=inputs, tensor_type="pt") hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language.model.embed_tokens + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language.model.embed_tokens + ) return hf_model @@ -358,11 +360,10 @@ def processor(*args, text="", images=None, **kwargs): assert len(contents) == len(images) return hf_processor.apply_chat_template( - [{ - "role": "user", - "image": image, - "content": content - } for image, content in zip(images, contents)], + [ + {"role": "user", "image": image, "content": content} + for image, content in zip(images, contents) + ], add_generation_prompt=True, tokenize=True, return_dict=True, @@ -370,8 +371,9 @@ def processor(*args, text="", images=None, **kwargs): ) hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.transformer.output_layer + ) return hf_model @@ -388,10 +390,9 @@ def processor(*args, videos=None, **kwargs): else: video_metadata = None - return hf_processor(*args, - videos=videos, - video_metadata=video_metadata, - **kwargs) + return hf_processor( + *args, videos=videos, video_metadata=video_metadata, **kwargs + ) hf_model.processor = processor return hf_model @@ -407,8 +408,9 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = 
self.config.use_thumbnail self.use_msac = self.config.use_msac @@ -416,11 +418,14 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): # yapf: disable from vllm.model_executor.models.h2ovl import ( - IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_h2ovl, + ) # yapf: enable images = [images] if isinstance(images, Image) else images @@ -432,29 +437,26 @@ def __call__(self, text: str, images: Union[Image, list[Image]], max_num=self.max_num, use_thumbnail=self.use_thumbnail, use_msac=self.use_msac, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) + text = text.replace("", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = H2OVLProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -468,19 +470,23 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.skyworkr1v import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_skyworkr1v) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_skyworkr1v, + ) + images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_skyworkr1v( @@ -489,29 +495,26 @@ def __call__(self, text: str, images: Union[Image, list[Image]], min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] for pixel_value in 
pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) + text = text.replace("", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = SkyworkR1VProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -525,8 +528,9 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch @@ -541,8 +545,13 @@ def __call__( **kwargs, ): from vllm.model_executor.models.internvl import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_internvl, video_to_pixel_values_internvl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_internvl, + video_to_pixel_values_internvl, + ) + images = [images] if isinstance(images, Image) else images videos = [videos] if isinstance(videos, np.ndarray) else videos if images is not None: @@ -553,7 +562,8 @@ def __call__( min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] num_patches_images = [ pixel_value.shape[0] for pixel_value in pixel_values_images @@ -569,7 +579,8 @@ def __call__( min_num=1, max_num=1, use_thumbnail=False, - ) for video in videos + ) + for video in videos ] num_patches_videos = [ pixel_value.shape[0] for pixel_value in pixel_values_videos @@ -581,38 +592,37 @@ def __call__( while ("" in text) or ("