|
12 | 12 | from math import prod
|
13 | 13 | from pathlib import Path
|
14 | 14 | from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
|
15 |
| -from transformers import AutoConfig |
| 15 | +from transformers import AutoConfig, AutoTokenizer |
16 | 16 |
|
17 | 17 | import torch
|
18 | 18 |
|
@@ -373,7 +373,22 @@ def set_type(self):
|
373 | 373 | def set_gguf_parameters(self):
|
374 | 374 | logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
|
375 | 375 | self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
|
376 |
| - if alora_invocation_tokens := lparams.get("alora_invocation_tokens"): |
| 376 | + alora_invocation_tokens = lparams.get("alora_invocation_tokens") |
| 377 | + invocation_string = lparams.get("invocation_string") |
| 378 | + if invocation_string and not alora_invocation_tokens: |
| 379 | + logger.debug("Tokenizing invocation_string -> alora_invocation_tokens") |
| 380 | + base_model_path_or_id = hparams.get("_name_or_path") |
| 381 | + try: |
| 382 | + tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id) |
| 383 | + except ValueError: |
| 384 | + logger.error("Unable to load tokenizer from %s", base_model_path_or_id) |
| 385 | + raise |
| 386 | + # NOTE: There's an off-by-one with the older aLoRAs where |
| 387 | + # the invocation string includes the "<|start_of_turn|>" |
| 388 | + # token, but the adapters themselves were trained to |
| 389 | + # activate _after_ that first token, so we drop it here. |
| 390 | + alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:] |
| 391 | + if alora_invocation_tokens: |
377 | 392 | logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
|
378 | 393 | self.gguf_writer.add_key_value(
|
379 | 394 | gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
|
|
0 commit comments