69 changes: 50 additions & 19 deletions tests/entrypoints/openai/test_lora_adapters.py
@@ -17,6 +17,33 @@
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

BADREQUEST_CASES = [
(
"test_rank",
{
"r": 1024
},
"is greater than max_lora_rank",
),
(
"test_bias",
{
"bias": "all"
},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {
"use_dora": True
}, "does not yet support DoRA"),
(
"test_modules_to_save",
{
"modules_to_save": ["lm_head"]
},
"only supports modules_to_save being None",
),
]
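# Each BADREQUEST_CASES entry is (test name, adapter_config override,
# expected error substring): the parametrized test below copies a known-good
# adapter, applies the override, and expects the server to reject the load
# with a matching openai.BadRequestError.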


@pytest.fixture(scope="module")
def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
-async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
-                                              tmp_path, zephyr_lora_files):
-    invalid_rank = tmp_path / "invalid_rank"
-
-    # Copy adapter from zephyr_lora_files to invalid_rank
-    shutil.copytree(zephyr_lora_files, invalid_rank)
-
-    with open(invalid_rank / "adapter_config.json") as f:
+@pytest.mark.parametrize("test_name,config_change,expected_error",
+                         BADREQUEST_CASES)
+async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
+                                        zephyr_lora_files, test_name: str,
+                                        config_change: dict,
+                                        expected_error: str):
+    # Create test directory
+    test_dir = tmp_path / test_name
+
+    # Copy adapter files
+    shutil.copytree(zephyr_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)

-    print(adapter_config)
-
-    # assert False
-
-    # Change rank to invalid value
-    adapter_config["r"] = 1024
-    with open(invalid_rank / "adapter_config.json", "w") as f:
+    # Save modified configuration
+    with open(config_path, "w") as f:
        json.dump(adapter_config, f)

-    with pytest.raises(openai.BadRequestError,
-                       match="is greater than max_lora_rank"):
+    # Test loading the adapter
+    with pytest.raises(openai.BadRequestError, match=expected_error):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
-                              "lora_name": "invalid-json",
-                              "lora_path": str(invalid_rank)
+                              "lora_name": test_name,
+                              "lora_path": str(test_dir)
})


16 changes: 16 additions & 0 deletions tests/lora/test_lora_checkpoints.py
@@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

@@ -30,11 +31,14 @@ def test_load_checkpoints(
else:
expected_lora_modules.append(module)
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
        # For the baichuan7B model, load its LoRA,
# and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -43,19 +47,25 @@
        # Test that target_modules carrying a full prefix, such as
        # "model.layers.0.self_atten.W_pack", load correctly;
        # the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_zero_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
        # Test `target_modules` given as regular expressions, such as
        # `model\\..*(W_pack|o_proj)`; the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -64,10 +74,13 @@
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
max_position_embeddings=4096)
with pytest.raises(ValueError, match=expected_error):
LoRAModel.from_local_checkpoint(
chatglm3_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
".layers.": ".baichuan_layers.",
},
)
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
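Across these tests the change is the same: the caller now builds a `PEFTHelper` from the adapter directory and hands it to `LoRAModel.from_local_checkpoint`, instead of letting the checkpoint loader parse the adapter config itself. A minimal sketch of the new calling convention follows; the path, module list, and embedding mappings are illustrative placeholders, not values from this PR:

```python
from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

adapter_dir = "/path/to/adapter"  # assumed: a local PEFT adapter directory
expected_lora_modules = ["q_proj", "v_proj"]  # illustrative, model-specific
embedding_modules = {}  # illustrative: no extra embedding modules
embedding_padding_modules = []  # illustrative

# The helper parses adapter_config.json (rank, alpha, target_modules, ...)
# once, up front.
peft_helper = PEFTHelper.from_local_dir(adapter_dir,
                                        max_position_embeddings=4096)
lora_model = LoRAModel.from_local_checkpoint(
    adapter_dir,
    expected_lora_modules,
    peft_helper=peft_helper,  # newly threaded through in this PR
    lora_model_id=1,
    device="cpu",
    embedding_modules=embedding_modules,
    embedding_padding_modules=embedding_padding_modules)
```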
3 changes: 3 additions & 0 deletions tests/lora/test_lora_huggingface.py
@@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_path = get_adapter_absolute_path(lora_name)

    # LoRA loading should work for both an absolute path and a Hugging Face id.
peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
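The same two-step pattern covers adapters referenced by Hugging Face id: `get_adapter_absolute_path` resolves the id to a local snapshot first, so `PEFTHelper.from_local_dir` always receives a local directory.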
59 changes: 2 additions & 57 deletions tests/lora/test_lora_manager.py
@@ -1,5 +1,3 @@
-import json
-import math
import os
from typing import Dict, List

@@ -34,68 +32,15 @@
] if current_platform.is_cuda_alike() else ["cpu"])


-def test_peft_helper(sql_lora_files):
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
-    assert peft_helper.r == 8
-    assert peft_helper.lora_alpha == 16
-    assert peft_helper.target_modules == [
-        "q_proj",
-        "v_proj",
-        "k_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    scaling = peft_helper.lora_alpha / peft_helper.r
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    # test RSLoRA
-    config = dict(r=8,
-                  lora_alpha=16,
-                  target_modules=["gate_proj"],
-                  use_rslora=True)
-    peft_helper = PEFTHelper.from_dict(config)
-
-    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    expected_error = "vLLM only supports modules_to_save being None."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(
-            r=8,
-            lora_alpha=16,
-            target_modules=["gate_proj"],
-            modules_to_save=["lm_head"],
-        )
-        PEFTHelper.from_dict(config)
-
-    expected_error = "vLLM does not yet support DoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_dora=True)
-        PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))

-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-
-    peft_helper = PEFTHelper.from_dict(config)
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
+                                            max_position_embeddings=4096)
lora_model = LoRAModel.from_lora_tensors(
1,
tensors,
109 changes: 109 additions & 0 deletions tests/lora/test_peft_helper.py
@@ -0,0 +1,109 @@
import json
import math
import shutil

import pytest

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

ERROR_CASES = [
(
"test_rank",
{
"r": 1024
},
"is greater than max_lora_rank",
),
(
"test_bias",
{
"bias": "all"
},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {
"use_dora": True
}, "does not yet support DoRA"),
(
"test_modules_to_save",
{
"modules_to_save": ["lm_head"]
},
"only supports modules_to_save being None",
),
]
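# ERROR_CASES mirrors BADREQUEST_CASES in
# tests/entrypoints/openai/test_lora_adapters.py: the same invalid adapter
# configs surface here as a ValueError from PEFTHelper.validate_legal and
# there as an openai.BadRequestError from the serving layer.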


def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
max_position_embeddings=4096)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
peft_helper.validate_legal(lora_config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
assert peft_helper.context_length == 16384
assert peft_helper.vllm_max_position_embeddings == 4096
assert peft_helper.vllm_long_context_scaling_factor == float(
math.ceil(peft_helper.context_length /
peft_helper.vllm_max_position_embeddings))
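    # For these fixtures this is ceil(16384 / 4096) = 4.0.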
    # Test RSLoRA (rank-stabilized LoRA): the scaling factor is
    # lora_alpha / sqrt(r) instead of the standard lora_alpha / r.
rslora_config = dict(use_rslora=True)
test_dir = tmp_path / "test_rslora"
shutil.copytree(long_context_lora_files_16k_1, test_dir)

# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(rslora_config)

# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)

peft_helper = PEFTHelper.from_local_dir(test_dir,
max_position_embeddings=4096)
peft_helper.validate_legal(lora_config)
scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3


@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
def test_peft_helper_error(
sql_lora_files,
tmp_path,
test_name: str,
config_change: dict,
expected_error: str,
):
test_dir = tmp_path / test_name
shutil.copytree(sql_lora_files, test_dir)

# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(config_change)

# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
# Test loading the adapter
with pytest.raises(ValueError, match=expected_error):
PEFTHelper.from_local_dir(
test_dir, max_position_embeddings=4096).validate_legal(lora_config)
1 change: 1 addition & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -296,6 +296,7 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
+            return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
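The added `return` is the functional fix here: without it, a failed `load_lora_adapter` request fell through after sending the `RPCError` and also sent an `RPCAdapterLoadedResponse`, so clients saw a success message following the error.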