diff --git a/convert.py b/convert.py
index 817cb66123a8f..15095cc47f389 100755
--- a/convert.py
+++ b/convert.py
@@ -1293,7 +1293,7 @@ def load_some_model(path: Path) -> ModelPlus:
 
 
 class VocabFactory:
-    _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
+    _FILES = {"spm": ["tokenizer.model"], "bpe": ["vocab.json", "tokenizer.json"], "hfft": ["tokenizer.json"]}
 
     def __init__(self, path: Path):
         self.path = path
@@ -1301,11 +1301,12 @@ def __init__(self, path: Path):
         print(f"Found vocab files: {self.file_paths}")
 
     def _detect_files(self) -> dict[str, Path | None]:
-        def locate(file: str) -> Path | None:
-            if (path := self.path / file).exists():
-                return path
-            if (path := self.path.parent / file).exists():
-                return path
+        def locate(files: list[str]) -> Path | None:
+            for file in files:
+                if (path := self.path / file).exists():
+                    return path
+                if (path := self.path.parent / file).exists():
+                    return path
             return None
 
         return {vt: locate(f) for vt, f in self._FILES.items()}