diff --git a/README.md b/README.md
index f3012749f2..a8f21c1fee 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
 |

 ## 🎉 News
+- 2024.09.24: Support for training and deploying llama3_1-8b-omni. Experience it using `swift infer --model_type llama3_1-8b-omni`.
 - 2024.09.23: Support for training and deploying pixtral-12b. Experience it using `swift infer --model_type pixtral-12b --dtype fp16`.
 - 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: Support the `Reflection-llama3-70b` model, use by `swift sft/infer --model_type reflection-llama_3_1-70b`.
@@ -635,6 +636,8 @@ The complete list of supported models and datasets can be found at [Supported Mo
 | PaliGemma | Google | English | 3B | chat model |
 | Florence | Microsoft | English | 0.23B-0.77B | chat model |
 | Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | English | 8B | chat model |
+| Pixtral | [mistralai](https://huggingface.co/mistralai) | English | 12B | chat model |
+| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | English | 8B | chat model |

 #### Diffusion Models

diff --git a/README_CN.md b/README_CN.md
index cbf75cbcdf..c61484e79c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -56,6 +56,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:

 ## 🎉 新闻
+- 2024.09.24: 支持llama3_1-8b-omni的训练与部署. 使用`swift infer --model_type llama3_1-8b-omni`进行体验.
 - 2024.09.23: 支持pixtral-12b的训练与部署. 使用`swift infer --model_type pixtral-12b --dtype fp16`进行体验.
 - 🔥2024.09.19: 支持qwen2.5、qwen2.5-math、qwen2.5-coder系列模型. 支持qwen2-vl-72b系列模型. 最佳实践可以查看[这里](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: 支持`Reflection-llama3-70b`模型, 使用`swift sft/infer --model_type reflection-llama_3_1-70b`命令即可训练和推理.
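The README news entries above give only the CLI form. For reference, here is a minimal Python sketch that mirrors `swift infer --model_type llama3_1-8b-omni` through the `InferArguments`/`infer_main` entry points already shown in the README quick-start; the audio prompt handling itself comes from the template added in `template.py` further below.

```python
# Hedged sketch: programmatic equivalent of `swift infer --model_type llama3_1-8b-omni`.
from swift.llm import InferArguments, ModelType, infer_main

infer_args = InferArguments(model_type=ModelType.llama3_1_8b_omni)
infer_main(infer_args)  # starts the interactive inference loop
```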
@@ -628,7 +629,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
 | PaliGemma | Google | 英文 | 3B | chat模型 |
 | Florence | 微软 | 英文 | 0.23B-0.77B | chat模型 |
 | Idefics3 | [HuggingFaceM4](https://huggingface.co/HuggingFaceM4) | 英文 | 8B | chat模型 |
-
+| Pixtral | [mistralai](https://huggingface.co/mistralai) | 英文 | 12B | chat模型 |
+| Llama3.1-Omni | [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni) | 英文 | 8B | chat模型 |

 #### 扩散模型

diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
index 776d50e14a..8c7d64a896 100644
--- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
+++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
@@ -438,6 +438,7 @@
 |qwen2-vl-72b-instruct-gptq-int8|[qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, auto_gptq>=0.5|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8)|
 |qwen2-vl-72b-instruct-awq|[qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-AWQ/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, autoawq|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)|
 |glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|^(transformer.encoder)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|glm4v|✘|✘|✘|✘|transformers>=4.42|vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
+|llama3_1-8b-omni|[ICTNLP/Llama-3.1-8B-Omni](https://modelscope.cn/models/ICTNLP/Llama-3.1-8B-Omni/summary)|^(model.layers\|model.speech_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llama3_1-omni|✔|✘|✘|✘|whisper, openai-whisper|audio|[ICTNLP/Llama-3.1-8B-Omni](https://huggingface.co/ICTNLP/Llama-3.1-8B-Omni)|
 |idefics3-8b-llama3|[AI-ModelScope/Idefics3-8B-Llama3](https://modelscope.cn/models/AI-ModelScope/Idefics3-8B-Llama3/summary)|^(model.text_model\|model.connector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|idefics3|✔|✘|✘|✘|transformers>=4.45.0.dev0|vision|[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)|
 |llava1_5-7b-instruct|[swift/llava-1.5-7b-hf](https://modelscope.cn/models/swift/llava-1.5-7b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)|
 |llava1_5-13b-instruct|[swift/llava-1.5-13b-hf](https://modelscope.cn/models/swift/llava-1.5-13b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)|
diff --git a/docs/source_en/Instruction/Supported-models-datasets.md b/docs/source_en/Instruction/Supported-models-datasets.md
index adf8623b59..00cad11b1b 100644
--- a/docs/source_en/Instruction/Supported-models-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-datasets.md
@@ -438,6 +438,7 @@ The table below introduces all models supported by SWIFT:
 |qwen2-vl-72b-instruct-gptq-int8|[qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, auto_gptq>=0.5|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8)|
 |qwen2-vl-72b-instruct-awq|[qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/qwen/Qwen2-VL-72B-Instruct-AWQ/summary)|^(model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|qwen2-vl|✔|✔|✘|✘|transformers>=4.45.0.dev0, qwen_vl_utils, autoawq|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)|
 |glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|^(transformer.encoder)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|glm4v|✘|✘|✘|✘|transformers>=4.42|vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)|
+|llama3_1-8b-omni|[ICTNLP/Llama-3.1-8B-Omni](https://modelscope.cn/models/ICTNLP/Llama-3.1-8B-Omni/summary)|^(model.layers\|model.speech_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llama3_1-omni|✔|✘|✘|✘|whisper, openai-whisper|audio|[ICTNLP/Llama-3.1-8B-Omni](https://huggingface.co/ICTNLP/Llama-3.1-8B-Omni)|
 |idefics3-8b-llama3|[AI-ModelScope/Idefics3-8B-Llama3](https://modelscope.cn/models/AI-ModelScope/Idefics3-8B-Llama3/summary)|^(model.text_model\|model.connector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|idefics3|✔|✘|✘|✘|transformers>=4.45.0.dev0|vision|[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)|
 |llava1_5-7b-instruct|[swift/llava-1.5-7b-hf](https://modelscope.cn/models/swift/llava-1.5-7b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)|
 |llava1_5-13b-instruct|[swift/llava-1.5-13b-hf](https://modelscope.cn/models/swift/llava-1.5-13b-hf/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|llava1_5|✔|✔|✘|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)|
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 1f5855f6b9..6f4f298ee0 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -1048,14 +1048,16 @@ def __post_init__(self) -> None:
             if self.eval_steps is None:
                 self.eval_steps = 50
         elif self.sft_type == 'full':
-            if self.freeze_vit:
-                from swift.utils.module_mapping import MODEL_KEYS_MAPPING
-                lora_target_modules = model_info.get('lora_target_modules')
-                vision_tower = None
-                if isinstance(lora_target_modules, str):
-                    vision_tower = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
-                if vision_tower:
-                    self.freeze_parameters += vision_tower
+            from swift.utils.module_mapping import MODEL_KEYS_MAPPING
+            lora_target_modules = model_info.get('lora_target_modules')  # model_group
+            model_arch = None
+            if isinstance(lora_target_modules, str):
+                model_arch = MODEL_KEYS_MAPPING[lora_target_modules]
+            if model_arch:
+                if self.freeze_vit and model_arch.vision_tower:
+                    self.freeze_parameters += model_arch.vision_tower
+                if model_arch.generator:
+                    self.freeze_parameters += model_arch.generator
             assert 0 <= self.freeze_parameters_ratio <= 1
             assert self.quantization_bit == 0, 'Full parameter fine-tuning does not support quantization.'
             assert self.dtype != 'fp16', ("Fine-tuning with dtype=='fp16' can lead to NaN issues. "
diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index 7285966aa8..5feafd6106 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -260,6 +260,8 @@ class ModelType:
     llama3_1_405b_instruct_awq = 'llama3_1-405b-instruct-awq'
     llama3_1_405b_instruct_gptq_int4 = 'llama3_1-405b-instruct-gptq-int4'
     llama3_1_405b_instruct_bnb = 'llama3_1-405b-instruct-bnb'
+    # omni
+    llama3_1_8b_omni = 'llama3_1-8b-omni'
     # reflection
     reflection_llama_3_1_70b = 'reflection-llama_3_1-70b'
     # long writer
@@ -633,6 +635,7 @@ class LoRATM(NamedTuple):
     florence = 'florence'
     idefics3 = 'idefics3'
     mplug_owl3 = 'mplug_owl3'
+    llama3_1_omni = 'llama3_1_omni'
    # default lora target modules for nlp llms.
     minicpm3 = ['q_a_proj', 'q_b_proj', 'kv_a_proj_with_mqa', 'kv_b_proj']
     baichuan = ['W_pack']
@@ -6507,6 +6510,49 @@ def get_model_tokenizer_mplug_owl2(model_dir: str,
     return model, tokenizer


+@register_model(
+    ModelType.llama3_1_8b_omni,
+    'ICTNLP/Llama-3.1-8B-Omni',
+    LoRATM.llama3_1_omni,
+    TemplateType.llama3_1_omni,
+    requires=['whisper', 'openai-whisper'],
+    support_flash_attn=True,
+    tags=['multi-modal', 'audio'],
+    hf_model_id='ICTNLP/Llama-3.1-8B-Omni')
+def get_model_tokenizer_omnli(model_dir: str,
+                              torch_dtype: torch.dtype,
+                              model_kwargs: Dict[str, Any],
+                              load_model: bool = True,
+                              **kwargs):
+    if 'local_repo_path' in kwargs:
+        local_repo_path = kwargs['local_repo_path']
+    else:
+        local_repo_path = git_clone_github('https://github.com/ictnlp/LLaMA-Omni')
+    local_repo_path = os.path.join(local_repo_path, 'LLaMA-Omni')
+    sys.path.append(os.path.join(local_repo_path))
+    from omni_speech.model import OmniSpeech2SLlamaForCausalLM, OmniSpeechLlamaForCausalLM
+    import whisper
+    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    model_config.speech_encoder = os.path.join(model_dir, 'large-v3.pt')
+    if not os.path.exists(model_config.speech_encoder):
+        whisper.load_model('large-v3', download_root=model_dir)
+    kwargs['automodel_class'] = OmniSpeech2SLlamaForCausalLM
+    kwargs['model_config'] = model_config
+    for key in ['forward', 'generate']:
+        try:
+            delattr(OmniSpeech2SLlamaForCausalLM, key)
+            delattr(OmniSpeechLlamaForCausalLM, key)
+        except AttributeError:
+            pass
+    # not support device_map='auto'
+    device_map = model_kwargs['device_map']
+    model_kwargs['device_map'] = None
+    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
+    if model:
+        model.to('cuda:0' if device_map == 'auto' else device_map)
+    return model, tokenizer
+
+
 def fix_transformers_upgrade(module: PreTrainedModel) -> None:
     # from 4.35, transformers changes its arguments of _set_gradient_checkpointing
     if version.parse(transformers.__version__) >= version.parse('4.35'):
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index eb5d555943..d23fa91bfc 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -60,6 +60,7 @@ class TemplateType:
     codegeex4 = 'codegeex4'
     llama = 'llama'  # llama2
     llama3 = 'llama3'
+    llama3_1_omni = 'llama3_1-omni'
     reflection = 'reflection'
     longwriter_llama3 = 'longwriter-llama3'
     # llava-hf
@@ -1812,6 +1813,57 @@ class ReflectionTemplate(Llama3TemplateMixin, Template):
 register_template(TemplateType.reflection, ReflectionTemplate())
 register_template(TemplateType.llama3, Llama3Template())

+
+class Llama3_1OmniTemplate(Llama3Template):
+    system = ('You are a helpful language and speech assistant. '
+              'You are able to understand the speech content that the user provides, '
+              'and assist the user with a variety of tasks using natural language.')
+
+    def replace_tag(self, media_type, index, example) -> List[Context]:
+        assert media_type == 'audio'
+        return [[-200]]
+
+    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        import whisper
+        inputs, _ = super()._encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        audios = example['audios']
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+        inputs['_data'] = {'input_ids': torch.tensor(input_ids)[None]}
+        if labels is not None:
+            inputs['_data']['labels'] = torch.tensor(labels)[None]
+        if audios:
+            audios = load_batch(audios, whisper.load_audio)
+            n_mels = get_env_args('n_mels', int, 128)
+            for i, audio in enumerate(audios):
+                audio = whisper.pad_or_trim(audio)
+                audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
+            audios = torch.stack(audios)
+            inputs['_data'].update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})
+
+        return inputs, {}
+
+    def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+        speech = data.get('speech')
+        input_ids = data['input_ids']
+        labels = data.get('labels')
+        if speech is not None:
+            speech_lengths = data['speech_lengths']
+            speech = speech.to(model.dtype)
+            inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
+                                                                                    speech, speech_lengths)[4:]
+        else:
+            inputs_embeds = model.get_model().embed_tokens(input_ids)
+        res = {'inputs_embeds': inputs_embeds[0]}
+        if labels is not None:
+            res['labels'] = labels[0]
+        return res
+
+
+register_template(TemplateType.llama3_1_omni, Llama3_1OmniTemplate(), lazy_tokenize=True)
+
 OPENBUDDY_DEFAULT_SYSTEM = (
     'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
     'Always answer as helpfully and logically as possible, while being safe. '
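As a sanity check on the `_encode` logic in `Llama3_1OmniTemplate` above, here is a standalone sketch of the per-sample audio preprocessing. It assumes `openai-whisper` is installed and uses a placeholder file `speech.wav`; `n_mels=128` is the default the template reads via `get_env_args`.

```python
# Minimal reproduction of the audio feature pipeline from Llama3_1OmniTemplate._encode.
import torch
import whisper

audio = whisper.load_audio('speech.wav')              # mono float32 waveform at 16 kHz
audio = whisper.pad_or_trim(audio)                    # pad/cut to 30 seconds, as in _encode
mel = whisper.log_mel_spectrogram(audio, n_mels=128)  # (128, 3000) log-mel features
speech = mel.permute(1, 0)[None]                      # (1, 3000, 128), batched for the speech encoder
speech_lengths = torch.tensor([[speech.shape[1]]])    # matches the template's speech_lengths
print(speech.shape, speech_lengths)
```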
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 5ec51d53f9..481d998860 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -462,7 +462,7 @@ def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
     from swift.utils.module_mapping import MODEL_KEYS_MAPPING
     from .model import MODEL_MAPPING
     model_info = MODEL_MAPPING[model_type]
-    lora_target_modules = model_info.get('lora_target_modules')
+    lora_target_modules = model_info.get('lora_target_modules')  # model_group
     if not isinstance(lora_target_modules, str):
         return
diff --git a/swift/utils/module_mapping.py b/swift/utils/module_mapping.py
index 4d4e4cfacc..2fd8298bb5 100644
--- a/swift/utils/module_mapping.py
+++ b/swift/utils/module_mapping.py
@@ -46,10 +46,11 @@ class MultiModelKeys(ModelKeys):
     language_model: Union[List[str], str] = field(default_factory=list)
     connector: Union[List[str], str] = field(default_factory=list)
     vision_tower: Union[List[str], str] = field(default_factory=list)
+    generator: Union[List[str], str] = field(default_factory=list)

     def __post_init__(self):
         # compat
-        for key in ['language_model', 'connector', 'vision_tower']:
+        for key in ['language_model', 'connector', 'vision_tower', 'generator']:
             v = getattr(self, key)
             if isinstance(v, str):
                 setattr(self, key, [v])
@@ -241,7 +242,6 @@ def __post_init__(self):

 COGVLM_KEYS = MultiModelKeys(
     language_model='model.layers',
-    connector=[],
     vision_tower='model.vision',
 )

@@ -253,13 +253,11 @@ def __post_init__(self):

 QWEN_VL_KEYS = MultiModelKeys(
     language_model='transformer.h',
-    connector=[],
     vision_tower='transformer.visual',
 )

 QWEN_AUDIO_KEYS = MultiModelKeys(
     language_model='transformer.h',
-    connector=[],
     vision_tower='transformer.audio',
 )

@@ -271,13 +269,11 @@ def __post_init__(self):

 QWEN2_VL_KEYS = MultiModelKeys(
     language_model='model',
-    connector=[],
     vision_tower='visual',
 )

 GLM4V_KEYS = MultiModelKeys(
     language_model='transformer.encoder',
-    connector=[],
     vision_tower='transformer.vision',
 )

@@ -287,6 +283,13 @@ def __post_init__(self):
     vision_tower='model.vision_model',
 )

+LLAMA3_1_OMNI = MultiModelKeys(
+    language_model='model.layers',
+    connector='model.speech_projector',
+    vision_tower='model.speech_encoder',
+    generator='speech_generator',
+)
+
 MODEL_KEYS_MAPPING = OrderedDict([
     # MLLM here
     ('qwen_audio', QWEN_AUDIO_KEYS),
@@ -306,6 +309,7 @@ def __post_init__(self):
     ('florence', FLORENCE_KEYS),
     ('idefics3', IDEFICS3_KEYS),
     ('mplug_owl3', MPLUG_OWL3_KEYS),
+    ('llama3_1_omni', LLAMA3_1_OMNI),
     # LLM begins here
     ('llama', LLAMA_KEYS),
     ('mistral', LLAMA_KEYS),
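Together with the `argument.py` change above, the new `generator` field lets full-parameter SFT freeze the speech generator alongside the speech encoder (which is stored in the `vision_tower` slot for this audio model). A simplified illustration of how the mapping is consumed, with `freeze_vit` set to True purely for demonstration; the real logic lives in `SftArguments.__post_init__`.

```python
# Sketch of how MultiModelKeys.generator is consumed for llama3_1-8b-omni.
# Simplified from the argument.py hunk; freeze_vit=True is shown for illustration only.
from swift.utils.module_mapping import MODEL_KEYS_MAPPING

model_arch = MODEL_KEYS_MAPPING['llama3_1_omni']
freeze_parameters = []
freeze_vit = True

if freeze_vit and model_arch.vision_tower:
    freeze_parameters += model_arch.vision_tower  # ['model.speech_encoder']
if model_arch.generator:
    freeze_parameters += model_arch.generator     # ['speech_generator']

# Parameters whose names start with these prefixes are excluded from training.
print(freeze_parameters)
```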