Merged (changes from all commits)
26 changes: 14 additions & 12 deletions tests/lora/test_layers.py
@@ -164,8 +164,8 @@ def populate_loras(
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.lora_b = sublora.lora_b[(sublora_len *
i):(sublora_len * (i + 1)), :]
sublora.optimize()
subloras.append(sublora)

@@ -304,9 +304,9 @@ def create_random_embedding_layer():
result = embedding(input_)
after_a = F.embedding(
input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)

@@ -445,9 +445,9 @@ def create_random_embedding_layer():
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)

@@ -575,7 +575,7 @@ def _pretest():
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
@@ -692,9 +692,10 @@ def create_random_linear_replicated_layer():

expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):

lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)

@@ -817,7 +818,7 @@ def create_random_linear_parallel_layer():
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)

@@ -965,9 +966,10 @@ class FakeConfig:
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
sublora.scaling)
result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
(i + 1)] += (
input_ @ sublora.lora_a.T @ sublora.lora_b.T *
sublora.scaling)
expected_results.append(result)
expected_result = torch.cat(expected_results)

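
Taken together, the test changes above reflect the new weight orientation: lora_a is now stored as (rank, input_dim) and lora_b as (output_dim, rank), so the reference output is computed as x @ lora_a.T @ lora_b.T * scaling, and embedding lookups use lora_a.T. A minimal sketch of that reference computation with illustrative shapes (not the actual test fixtures):

import torch

# Illustrative dimensions only; the real tests derive these from the layer under test.
batch, input_dim, output_dim, rank, scaling = 4, 64, 128, 8, 0.5

x = torch.rand(batch, input_dim)
lora_a = torch.rand(rank, input_dim)     # new layout: (rank, input_dim)
lora_b = torch.rand(output_dim, rank)    # new layout: (output_dim, rank)

base_out = torch.zeros(batch, output_dim)   # stand-in for the base layer's output
expected = base_out + x @ lora_a.T @ lora_b.T * scaling
assert expected.shape == (batch, output_dim)
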
12 changes: 6 additions & 6 deletions tests/lora/test_lora_manager.py
@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
assert lora.lora_a.shape[0] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module:
@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0]], device=device),
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0], 8], device=device),
)
return LoRAModel(lora_id, 8, loras)

@@ -109,8 +109,8 @@ def create_packed_lora(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0] // len(replaced_module_names), 8],
device=device),
)
return LoRAModel(lora_id, 8, loras)
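
The reworked assertions and fixtures encode the same invariant: the rank lives in lora_a.shape[0] and lora_b.shape[1], and the random factors are created directly in that orientation. A tiny self-contained check with made-up dimensions, not the sql_lora fixture:

import torch

rank, in_features, out_features = 8, 32, 48
lora_a = torch.rand(rank, in_features)
lora_b = torch.rand(out_features, rank)

assert lora_a.shape[0] == lora_b.shape[1] == rank
x = torch.rand(4, in_features)
assert (x @ lora_a.T @ lora_b.T).shape == (4, out_features)
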
8 changes: 4 additions & 4 deletions tests/lora/utils.py
@@ -36,10 +36,10 @@ def init_random_lora(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
lora_a=torch.rand([rank, weight.shape[1]],
dtype=weight.dtype,
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
lora_b=torch.rand([weight.shape[0], rank],
dtype=weight.dtype,
device=self._device),
)
@@ -67,8 +67,8 @@ def init_lora(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
lora_a=torch.rand([rank, input_dim], device="cuda"),
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)
10 changes: 5 additions & 5 deletions vllm/lora/layers/base_linear.py
@@ -121,18 +121,18 @@ def set_lora(
lora_bias = self.slice_bias(lora_bias)

self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if lora_bias is not None:

self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True)
lora_bias, non_blocking=True)

def apply(self,
x: torch.Tensor,
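
With the transposed storage, set_lora can copy lora_a and lora_b into the stacked buffers without the previous .T, because the buffers are already laid out as (rank, input_dim) and (output_dim, rank). A rough sketch of that copy against assumed pre-allocated buffers (names and sizes are illustrative, not the vLLM internals):

import torch

max_loras, max_rank, input_dim, output_dim = 2, 16, 64, 128
index, rank = 0, 8

# Assumed stacked buffers, shaped to match the new convention.
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, input_dim)
lora_b_stacked = torch.zeros(max_loras, 1, output_dim, max_rank)

lora_a = torch.rand(rank, input_dim)
lora_b = torch.rand(output_dim, rank)

# Direct copies: no transpose needed any more.
lora_a_stacked[index, 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(lora_a)
lora_b_stacked[index, 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(lora_b)
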
67 changes: 33 additions & 34 deletions vllm/lora/layers/column_parallel_linear.py
@@ -99,21 +99,21 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2
offset = lora_b.shape[-1] // 2
offset = lora_b.shape[0] // 2

left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
shard_size]
right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size]
lora_b = torch.cat([left_weight, right_weight], dim=1)
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
shard_size, :]
right_weight = lora_b[offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[start_idx:end_idx, :]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@@ -251,9 +251,8 @@ def slice_lora_b(
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[:,
shard_size * shard_id:shard_size *
(shard_id + 1)]
sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
(shard_id + 1), :]
return sliced_lora_b

def slice_bias(
@@ -285,12 +284,12 @@ def set_lora(
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
lora_a_i.T, non_blocking=True)
index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
lora_b_i.T, non_blocking=True)
index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
lora_b_i, non_blocking=True)

if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
@@ -299,7 +298,7 @@ def set_lora(
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index,
0, :lora_bias_i.shape[0]].copy_(
lora_bias_i.T,
lora_bias_i,
non_blocking=True)

@classmethod
@@ -345,18 +344,18 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
(self.q_shard_id + 1), :]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
lora_b_k = lora_b[k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
lora_b_v = lora_b[v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@@ -465,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a

def apply(self,
@@ -508,10 +507,10 @@ def slice_lora_a(
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None,
lora_a[0][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[0] is not None else None,
lora_a[1][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[1] is not None else None,
]
return lora_a

@@ -551,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a

def apply(self,
@@ -589,12 +588,12 @@ def slice_lora_a(
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None,
lora_a[0][start_idx[0]:start_idx[0] +
shard_size[0], :] if lora_a[0] is not None else None,
lora_a[1][start_idx[1]:start_idx[1] +
shard_size[1], :] if lora_a[1] is not None else None,
lora_a[2][start_idx[2]:start_idx[2] +
shard_size[2], :] if lora_a[2] is not None else None,
]
return lora_a

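
The slicing changes in this file all follow the same rule: the output dimension being sharded across tensor-parallel ranks is now the first axis of lora_b, so shards are taken as row ranges rather than column ranges (and the fully sharded variants take their lora_a shard along the first axis as well). A simplified sketch of one column-parallel shard of lora_b, with made-up sizes standing in for the distributed helpers:

import torch

tp_size, tp_rank = 4, 1                      # pretend tensor-parallel world
total_output, rank = 256, 8
shard_size = total_output // tp_size

lora_b = torch.rand(total_output, rank)      # new layout: (output_dim, rank)
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
lora_b_shard = lora_b[start_idx:end_idx, :]  # slice rows, not columns
assert lora_b_shard.shape == (shard_size, rank)
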
8 changes: 4 additions & 4 deletions vllm/lora/layers/logits_processor.py
@@ -140,11 +140,11 @@ def set_lora(
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
4 changes: 2 additions & 2 deletions vllm/lora/layers/row_parallel_linear.py
@@ -39,7 +39,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
lora_a = lora_a[:, start_idx:end_idx]
return lora_a

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
@@ -122,7 +122,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[start_idx:end_idx, :]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
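
Row-parallel layers shard the input dimension, which after the transpose is the second axis of lora_a, so its slice moves from rows to columns, while the sharded lora_b slice moves to rows. A small sketch of the lora_a case under assumed sizes:

import torch

tp_size, tp_rank = 4, 2
total_input, rank = 256, 8
shard_size = total_input // tp_size

lora_a = torch.rand(rank, total_input)       # new layout: (rank, input_dim)
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
lora_a_shard = lora_a[:, start_idx:end_idx]  # input shards are now columns
assert lora_a_shard.shape == (rank, shard_size)
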
10 changes: 6 additions & 4 deletions vllm/lora/layers/vocal_parallel_embedding.py
@@ -95,11 +95,13 @@ def set_lora(
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
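
The embedding path is the one place that still transposes: F.embedding expects a (num_embeddings, embedding_dim) table, and for embedding modules lora_a is stored as (rank, vocab_size), so lora_a.T is what gets written into the row-major stacked buffer. A toy illustration of why the lookup works on the transposed tensor (sizes are invented):

import torch
import torch.nn.functional as F

vocab_size, rank, embed_dim = 1000, 8, 64
lora_a = torch.rand(rank, vocab_size)        # stored as (rank, vocab_size)
lora_b = torch.rand(embed_dim, rank)         # stored as (embed_dim, rank)

token_ids = torch.tensor([1, 5, 42])
after_a = F.embedding(token_ids, lora_a.T)   # table must be (vocab_size, rank)
delta = after_a @ lora_b.T                   # LoRA contribution, (3, embed_dim)
assert delta.shape == (3, embed_dim)
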
4 changes: 2 additions & 2 deletions vllm/lora/lora_weights.py
@@ -86,11 +86,11 @@ def create_dummy_lora_weights(
embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
lora_a = torch.zeros([rank, input_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
lora_b = torch.zeros([output_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
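
create_dummy_lora_weights now allocates its placeholders in the same orientation, lora_a as (rank, input_dim) and lora_b as (output_dim, rank). A stripped-down version of that allocation under assumed arguments:

import torch

input_dim, output_dim, rank = 64, 128, 8
dtype, device = torch.float16, "cpu"

lora_a = torch.zeros(rank, input_dim, dtype=dtype, device=device)
lora_b = torch.zeros(output_dim, rank, dtype=dtype, device=device)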