diff --git a/README.md b/README.md
index f3012749f2..a8f21c1fee 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
|
## 🎉 News
+- 2024.09.24: Support for training and deploying llama3_1-8b-omni. Experience it using `swift infer --model_type llama3_1-8b-omni`.
- 2024.09.23: Support for training and deploying pixtral-12b. Experience it using `swift infer --model_type pixtral-12b --dtype fp16`.
- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: Support for the `Reflection-llama3-70b` model. Train and infer with `swift sft/infer --model_type reflection-llama_3_1-70b`.
@@ -635,6 +636,8 @@ The complete list of supported models and datasets can be found at [Supported Mo
| PaliGemma | Google | English | 3B | chat model |
| Florence | Microsoft | English | 0.23B-0.77B | chat model |
| Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model |
+| Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model |
+| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model |
#### Diffusion Models
diff --git a/README_CN.md b/README_CN.md
index cbf75cbcdf..c61484e79c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -56,6 +56,7 @@ SWIFT has rich and comprehensive documentation; please check our documentation website:
## 🎉 新闻
+- 2024.09.24: Support for training and deploying llama3_1-8b-omni. Experience it using `swift infer --model_type llama3_1-8b-omni`.
 - 2024.09.23: Support for training and deploying pixtral-12b. Experience it using `swift infer --model_type pixtral-12b --dtype fp16`.
 - 🔥2024.09.19: Support for the qwen2.5, qwen2.5-math, and qwen2.5-coder series models, as well as the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: Support for the `Reflection-llama3-70b` model. Train and infer with `swift sft/infer --model_type reflection-llama_3_1-70b`.
@@ -628,7 +629,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
 | PaliGemma | Google | English | 3B | chat model |
 | Florence | Microsoft | English | 0.23B-0.77B | chat model |
 | Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model |
-
+| Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model |
+| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model |
 #### Diffusion Models
diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
index 776d50e14a..8c7d64a896 100644
--- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
+++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
@@ -438,6 +438,7 @@
|qwen2-vl-72b-instruct-gptq-int8|[qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, auto_gptq>=0.5|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8)|
|qwen2-vl-72b-instruct-awq|[qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-AWQ/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, autoawq|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)|
|glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|^(transformer.encoder)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|glm4v|✘|✘|✘|✘|transformers>=4.42|vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
+|llama3_1-8b-omni|[ICTNLP/Llama-3.1-8B-Omni](https://modelscope.cn/models/ICTNLP/Llama-3.1-8B-Omni/summary)|^(model.layers\|model.speech_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llama3_1-omni|✔|✘|✘|✘|whisper, openai-whisper|audio|[ICTNLP/Llama-3.1-8B-Omni](https://huggingface.co/ICTNLP/Llama-3.1-8B-Omni)|
|idefics3-8b-llama3|[AI-ModelScope/Idefics3-8B-Llama3](https://modelscope.cn/models/AI-ModelScope/Idefics3-8B-Llama3/summary)|^(model.text_model\|model.connector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|idefics3|✔|✘|✘|✘|transformers>=4.45.0.dev0|vision|[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)|
|llava1_5-7b-instruct|[swift/llava-1.5-7b-hf](https://modelscope.cn/models/swift/llava-1.5-7b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)|
|llava1_5-13b-instruct|[swift/llava-1.5-13b-hf](https://modelscope.cn/models/swift/llava-1.5-13b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)|
diff --git a/docs/source_en/Instruction/Supported-models-datasets.md b/docs/source_en/Instruction/Supported-models-datasets.md
index adf8623b59..00cad11b1b 100644
--- a/docs/source_en/Instruction/Supported-models-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-datasets.md
@@ -438,6 +438,7 @@ The table below introduces all models supported by SWIFT:
|qwen2-vl-72b-instruct-gptq-int8|[qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, auto_gptq>=0.5|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8)|
|qwen2-vl-72b-instruct-awq|[qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-AWQ/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, autoawq|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)|
|glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|^(transformer.encoder)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|glm4v|✘|✘|✘|✘|transformers>=4.42|vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
+|llama3_1-8b-omni|[ICTNLP/Llama-3.1-8B-Omni](https://modelscope.cn/models/ICTNLP/Llama-3.1-8B-Omni/summary)|^(model.layers\|model.speech_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llama3_1-omni|✔|✘|✘|✘|whisper, openai-whisper|audio|[ICTNLP/Llama-3.1-8B-Omni](https://huggingface.co/ICTNLP/Llama-3.1-8B-Omni)|
|idefics3-8b-llama3|[AI-ModelScope/Idefics3-8B-Llama3](https://modelscope.cn/models/AI-ModelScope/Idefics3-8B-Llama3/summary)|^(model.text_model\|model.connector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|idefics3|✔|✘|✘|✘|transformers>=4.45.0.dev0|vision|[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)|
|llava1_5-7b-instruct|[swift/llava-1.5-7b-hf](https://modelscope.cn/models/swift/llava-1.5-7b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)|
|llava1_5-13b-instruct|[swift/llava-1.5-13b-hf](https://modelscope.cn/models/swift/llava-1.5-13b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)|
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 1f5855f6b9..6f4f298ee0 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -1048,14 +1048,16 @@ def __post_init__(self) -> None:
if self.eval_steps is None:
self.eval_steps = 50
elif self.sft_type == 'full':
- if self.freeze_vit:
- from swift.utils.module_mapping import MODEL_KEYS_MAPPING
- lora_target_modules = model_info.get('lora_target_modules')
- vision_tower = None
- if isinstance(lora_target_modules, str):
- vision_tower = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
- if vision_tower:
- self.freeze_parameters += vision_tower
+ from swift.utils.module_mapping import MODEL_KEYS_MAPPING
+            lora_target_modules = model_info.get('lora_target_modules')  # may be a str: the model_group key into MODEL_KEYS_MAPPING
+ model_arch = None
+ if isinstance(lora_target_modules, str):
+ model_arch = MODEL_KEYS_MAPPING[lora_target_modules]
+ if model_arch:
+ if self.freeze_vit and model_arch.vision_tower:
+ self.freeze_parameters += model_arch.vision_tower
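+                # generator modules (e.g. the speech decoder) stay frozen during full-parameter training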
+ if model_arch.generator:
+ self.freeze_parameters += model_arch.generator
assert 0 <= self.freeze_parameters_ratio <= 1
assert self.quantization_bit == 0, 'Full parameter fine-tuning does not support quantization.'
assert self.dtype != 'fp16', ("Fine-tuning with dtype=='fp16' can lead to NaN issues. "
diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index 7285966aa8..5feafd6106 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -260,6 +260,8 @@ class ModelType:
llama3_1_405b_instruct_awq = 'llama3_1-405b-instruct-awq'
llama3_1_405b_instruct_gptq_int4 = 'llama3_1-405b-instruct-gptq-int4'
llama3_1_405b_instruct_bnb = 'llama3_1-405b-instruct-bnb'
+ # omni
+ llama3_1_8b_omni = 'llama3_1-8b-omni'
# reflection
reflection_llama_3_1_70b = 'reflection-llama_3_1-70b'
# long writer
@@ -633,6 +635,7 @@ class LoRATM(NamedTuple):
florence = 'florence'
idefics3 = 'idefics3'
mplug_owl3 = 'mplug_owl3'
+ llama3_1_omni = 'llama3_1_omni'
# default lora target modules for nlp llms.
minicpm3 = ['q_a_proj', 'q_b_proj', 'kv_a_proj_with_mqa', 'kv_b_proj']
baichuan = ['W_pack']
@@ -6507,6 +6510,49 @@ def get_model_tokenizer_mplug_owl2(model_dir: str,
return model, tokenizer
+@register_model(
+ ModelType.llama3_1_8b_omni,
+ 'ICTNLP/Llama-3.1-8B-Omni',
+ LoRATM.llama3_1_omni,
+ TemplateType.llama3_1_omni,
+ requires=['whisper', 'openai-whisper'],
+ support_flash_attn=True,
+ tags=['multi-modal', 'audio'],
+ hf_model_id='ICTNLP/Llama-3.1-8B-Omni')
+def get_model_tokenizer_omni(model_dir: str,
+                             torch_dtype: torch.dtype,
+                             model_kwargs: Dict[str, Any],
+                             load_model: bool = True,
+                             **kwargs):
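+    """Load ICTNLP/Llama-3.1-8B-Omni together with its Whisper speech encoder.
+
+    Clones the upstream LLaMA-Omni repo to import the `omni_speech` package and
+    downloads the Whisper large-v3 weights into the model directory if missing.
+    """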
+ if 'local_repo_path' in kwargs:
+ local_repo_path = kwargs['local_repo_path']
+ else:
+ local_repo_path = git_clone_github('https://github.com/ictnlp/LLaMA-Omni')
+ local_repo_path = os.path.join(local_repo_path, 'LLaMA-Omni')
+    sys.path.append(local_repo_path)
+ from omni_speech.model import OmniSpeech2SLlamaForCausalLM, OmniSpeechLlamaForCausalLM
+ import whisper
+ model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+ model_config.speech_encoder = os.path.join(model_dir, 'large-v3.pt')
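+    # Download the Whisper large-v3 encoder weights into model_dir if they are not already present.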
+ if not os.path.exists(model_config.speech_encoder):
+ whisper.load_model('large-v3', download_root=model_dir)
+ kwargs['automodel_class'] = OmniSpeech2SLlamaForCausalLM
+ kwargs['model_config'] = model_config
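+    # Remove the repo's custom forward/generate overrides so the inherited transformers implementations are used.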
+ for key in ['forward', 'generate']:
+ try:
+ delattr(OmniSpeech2SLlamaForCausalLM, key)
+ delattr(OmniSpeechLlamaForCausalLM, key)
+ except AttributeError:
+ pass
+    # device_map='auto' is not supported by this model; load on a single device and move it manually.
+ device_map = model_kwargs['device_map']
+ model_kwargs['device_map'] = None
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
+ if model:
+ model.to('cuda:0' if device_map == 'auto' else device_map)
+ return model, tokenizer
+
+
def fix_transformers_upgrade(module: PreTrainedModel) -> None:
# from 4.35, transformers changes its arguments of _set_gradient_checkpointing
if version.parse(transformers.__version__) >= version.parse('4.35'):
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index eb5d555943..d23fa91bfc 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -60,6 +60,7 @@ class TemplateType:
codegeex4 = 'codegeex4'
llama = 'llama' # llama2
llama3 = 'llama3'
+ llama3_1_omni = 'llama3_1-omni'
reflection = 'reflection'
longwriter_llama3 = 'longwriter-llama3'
# llava-hf
@@ -1812,6 +1813,57 @@ class ReflectionTemplate(Llama3TemplateMixin, Template):
register_template(TemplateType.reflection, ReflectionTemplate())
register_template(TemplateType.llama3, Llama3Template())
+
+class Llama3_1OmniTemplate(Llama3Template):
+ system = ('You are a helpful language and speech assistant. '
+ 'You are able to understand the speech content that the user provides, '
+ 'and assist the user with a variety of tasks using natural language.')
+
+ def replace_tag(self, media_type, index, example) -> List[Context]:
+ assert media_type == 'audio'
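+        # -200 is the speech placeholder token index; LLaMA-Omni swaps it for audio features later.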
+ return [[-200]]
+
+ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ import whisper
+ inputs, _ = super()._encode(example)
+ if len(inputs) == 0:
+ return inputs, {}
+ audios = example['audios']
+ input_ids = inputs['input_ids']
+ labels = inputs['labels']
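+        # Keep batched tensors under '_data'; speech fusion is deferred to _post_encode.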
+ inputs['_data'] = {'input_ids': torch.tensor(input_ids)[None]}
+ if labels is not None:
+ inputs['_data']['labels'] = torch.tensor(labels)[None]
+ if audios:
+ audios = load_batch(audios, whisper.load_audio)
+ n_mels = get_env_args('n_mels', int, 128)
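+            # Whisper large-v3 uses 128 mel bins; pad/trim each clip to 30s and build (T, n_mels) log-mel features.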
+ for i, audio in enumerate(audios):
+ audio = whisper.pad_or_trim(audio)
+ audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
+ audios = torch.stack(audios)
+ inputs['_data'].update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})
+
+ return inputs, {}
+
+ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+ speech = data.get('speech')
+ input_ids = data['input_ids']
+ labels = data.get('labels')
+ if speech is not None:
+ speech_lengths = data['speech_lengths']
+ speech = speech.to(model.dtype)
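+            # Splice the encoded speech features into the token embeddings at the -200 placeholder positions.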
+ inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
+ speech, speech_lengths)[4:]
+ else:
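+            # Text-only sample: embed the tokens directly.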
+ inputs_embeds = model.get_model().embed_tokens(input_ids)
+ res = {'inputs_embeds': inputs_embeds[0]}
+ if labels is not None:
+ res['labels'] = labels[0]
+ return res
+
+
+register_template(TemplateType.llama3_1_omni, Llama3_1OmniTemplate(), lazy_tokenize=True)
+
OPENBUDDY_DEFAULT_SYSTEM = (
'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
'Always answer as helpfully and logically as possible, while being safe. '
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 5ec51d53f9..481d998860 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -462,7 +462,7 @@ def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
from swift.utils.module_mapping import MODEL_KEYS_MAPPING
from .model import MODEL_MAPPING
model_info = MODEL_MAPPING[model_type]
- lora_target_modules = model_info.get('lora_target_modules')
+    lora_target_modules = model_info.get('lora_target_modules')  # may be a str: the model_group key into MODEL_KEYS_MAPPING
if not isinstance(lora_target_modules, str):
return
diff --git a/swift/utils/module_mapping.py b/swift/utils/module_mapping.py
index 4d4e4cfacc..2fd8298bb5 100644
--- a/swift/utils/module_mapping.py
+++ b/swift/utils/module_mapping.py
@@ -46,10 +46,11 @@ class MultiModelKeys(ModelKeys):
language_model: Union[List[str], str] = field(default_factory=list)
connector: Union[List[str], str] = field(default_factory=list)
vision_tower: Union[List[str], str] = field(default_factory=list)
+ generator: Union[List[str], str] = field(default_factory=list)
def __post_init__(self):
# compat
- for key in ['language_model', 'connector', 'vision_tower']:
+ for key in ['language_model', 'connector', 'vision_tower', 'generator']:
v = getattr(self, key)
if isinstance(v, str):
setattr(self, key, [v])
@@ -241,7 +242,6 @@ def __post_init__(self):
COGVLM_KEYS = MultiModelKeys(
language_model='model.layers',
- connector=[],
vision_tower='model.vision',
)
@@ -253,13 +253,11 @@ def __post_init__(self):
QWEN_VL_KEYS = MultiModelKeys(
language_model='transformer.h',
- connector=[],
vision_tower='transformer.visual',
)
QWEN_AUDIO_KEYS = MultiModelKeys(
language_model='transformer.h',
- connector=[],
vision_tower='transformer.audio',
)
@@ -271,13 +269,11 @@ def __post_init__(self):
QWEN2_VL_KEYS = MultiModelKeys(
language_model='model',
- connector=[],
vision_tower='visual',
)
GLM4V_KEYS = MultiModelKeys(
language_model='transformer.encoder',
- connector=[],
vision_tower='transformer.vision',
)
@@ -287,6 +283,13 @@ def __post_init__(self):
vision_tower='model.vision_model',
)
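+# The vision_tower slot is reused for the speech encoder here, mirroring the Qwen-Audio mapping above.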
+LLAMA3_1_OMNI = MultiModelKeys(
+ language_model='model.layers',
+ connector='model.speech_projector',
+ vision_tower='model.speech_encoder',
+ generator='speech_generator',
+)
+
MODEL_KEYS_MAPPING = OrderedDict([
# MLLM here
('qwen_audio', QWEN_AUDIO_KEYS),
@@ -306,6 +309,7 @@ def __post_init__(self):
('florence', FLORENCE_KEYS),
('idefics3', IDEFICS3_KEYS),
('mplug_owl3', MPLUG_OWL3_KEYS),
+ ('llama3_1_omni', LLAMA3_1_OMNI),
# LLM begins here
('llama', LLAMA_KEYS),
('mistral', LLAMA_KEYS),