
Commit 359ede2

Support for fine-tuning Pixtral-12B. (#2090)
1 parent b5e9972 commit 359ede2


10 files changed: +106 −6 lines changed


README.md

Lines changed: 2 additions & 1 deletion

@@ -55,7 +55,8 @@ You can contact us and communicate with us by adding our group:
 <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

 ## 🎉 News
-- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models.
+- 2024.09.23: Support for training and deploying pixtral-12b. Experience it using `swift infer --model_type pixtral-12b --dtype fp16`.
+- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: Support the `Reflection-llama3-70b` model, use by `swift sft/infer --model_type reflection-llama_3_1-70b`.
 - 2024.09.06: Support fine-tuning and inference for mplug-owl3. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/1969).
 - 2024.09.05: Support for the minicpm3-4b model. Experience it using `swift infer --model_type minicpm3-4b`.
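The news entry above only shows inference. Since this commit adds fine-tuning support, here is a hedged sketch of what a training launch could look like; the flags mirror the ms-swift CLI used elsewhere in this README, and the dataset name is purely illustrative, not something the commit prescribes.

    # Hedged sketch: launch Pixtral-12B fine-tuning via the ms-swift CLI from Python.
    # Assumptions: ms-swift is installed, a GPU is available, and 'coco-en-mini' (illustrative)
    # is a dataset registered in your environment; substitute your own dataset.
    import subprocess

    subprocess.run([
        'swift', 'sft',
        '--model_type', 'pixtral-12b',
        '--dtype', 'fp16',            # the model registration in this commit discourages bf16
        '--dataset', 'coco-en-mini',  # illustrative dataset name
    ], check=True)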

README_CN.md

Lines changed: 2 additions & 1 deletion

@@ -56,7 +56,8 @@ SWIFT has rich and comprehensive documentation; please see our documentation site:


 ## 🎉 News
-- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models.
+- 2024.09.23: Supports training and deployment of pixtral-12b. Try it with `swift infer --model_type pixtral-12b --dtype fp16`.
+- 🔥2024.09.19: Supports the qwen2.5, qwen2.5-math, and qwen2.5-coder series models. Supports the qwen2-vl-72b series models. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/2064).
 - 2024.09.07: Supports the `Reflection-llama3-70b` model; train and infer with `swift sft/infer --model_type reflection-llama_3_1-70b`.
 - 2024.09.06: Supports fine-tuning and inference for mplug-owl3. Best practices can be found [here](https://github.com/modelscope/ms-swift/issues/1969).
 - 2024.09.05: Supports the minicpm3-4b model. Try it with `swift infer --model_type minicpm3-4b`.

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@
 - `--test_oom_error`: Used to check whether training will hit OOM; default is `False`. If set to True, the training set is sorted in descending order of max_length to make OOM testing easier. This parameter is generally used for testing; set it with care.
 - `--disable_tqdm`: Whether to disable tqdm, which is useful when launching scripts with `nohup`. Default is `False`, i.e. tqdm is enabled.
 - `--🔥lazy_tokenize`: If set to False, all text is preprocessed before `trainer.train()`. If set to True, text encoding is deferred, reducing preprocessing wait time and memory usage, which is useful for large datasets. Default is `None`, i.e. the value is chosen automatically based on the template type: usually False for LLMs and True for multimodal models (to avoid excessive memory usage from loading images and audio).
-- `--🔥preprocess_num_proc`: Use multiple processes when preprocessing the dataset (tokenizing text). Default is `1`. Like the `lazy_tokenize` argument, it addresses slow preprocessing, but it cannot reduce memory usage, so for huge datasets `lazy_tokenize` is recommended. Recommended values: 4, 8. Note: when using qwen-audio, this parameter is forced to 1, because qwen-audio's preprocessing function uses torch multiprocessing, which causes compatibility issues.
+- `--🔥preprocess_num_proc`: Use multiple processes when preprocessing the dataset (tokenizing text). Default is `1`. Like the `lazy_tokenize` argument, it addresses slow preprocessing, but it cannot reduce memory usage, so for huge datasets `lazy_tokenize` is recommended. Recommended values: 4, 8.
 - `--🔥use_flash_attn`: Whether to use flash attn; default is `None`. Installation steps for flash_attn can be found at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Models supporting flash_attn are listed in [Supported Models](支持的模型和数据集.md#模型).
 - `--ignore_args_error`: Whether to ignore Errors raised by command-line argument mistakes; default is `False`. Set it to True if you need to copy code into a notebook to run it.
 - `--🔥check_model_is_latest`: Check whether the model is the latest version; default is `True`. Set this to `False` if you need to train offline.

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 1 addition & 0 deletions

@@ -492,6 +492,7 @@
 |minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
 |minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|&#x2714;|&#x2714;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
 |minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|&#x2714;|&#x2714;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision, video|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
+|pixtral-12b|[AI-ModelScope/pixtral-12b](https://modelscope.cn/models/AI-ModelScope/pixtral-12b/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|pixtral|&#x2718;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.45.0.dev0|vision|[mistral-community/pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b)|
 |mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
 |mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
 |mplug-owl3-7b-chat|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728/summary)|^(language_model\|vision2text_model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|mplug_owl3|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.36, icecream|vision|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@
 - `--test_oom_error`: Used to detect whether training will cause OOM, default is `False`. If set to True, will sort the training set in descending order by max_length, easy for OOM testing. This parameter is generally used for testing, use carefully.
 - `--disable_tqdm`: Whether to disable tqdm, useful when launching script with `nohup`. Default is `False`, i.e. enable tqdm.
 - `--🔥lazy_tokenize`: If set to False, preprocess all text before `trainer.train()`. If set to True, delay encoding text, reducing preprocessing wait and memory usage, useful when processing large datasets. Default is `None`, i.e. we intelligently choose based on template type, usually set to False for LLM models, set to True for multimodal models (to avoid excessive memory usage from loading images and audio).
-- `--🔥preprocess_num_proc`: Use multiprocessing when preprocessing dataset (tokenizing text). Default is `1`. Same as `lazy_tokenize` command line argument, used to solve slow preprocessing issue. But this strategy cannot reduce memory usage, so if dataset is huge, `lazy_tokenize` is recommended. Recommended values: 4, 8. Note: When using qwen-audio, this parameter will be forced to 1, because qwen-audio's preprocessing function uses torch's multiprocessing, which will cause compatibility issues.
+- `--🔥preprocess_num_proc`: Use multiprocessing when preprocessing dataset (tokenizing text). Default is `1`. Same as `lazy_tokenize` command line argument, used to solve slow preprocessing issue. But this strategy cannot reduce memory usage, so if dataset is huge, `lazy_tokenize` is recommended. Recommended values: 4, 8.
 - `--🔥use_flash_attn`: Whether to use flash attn, default is `None`. Installation steps for flash_attn can be found at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Models supporting flash_attn can be found in [LLM Supported Models](Supported-models-datasets.md).
 - `--ignore_args_error`: Whether to ignore Error thrown by command line parameter errors, default is `False`. Set to True if need to copy code to notebook to run.
 - `--🔥check_model_is_latest`: Check if model is latest, default is `True`. Set this to `False` if you need to train offline.
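A brief, hedged sketch of how the two preprocessing knobs documented above are typically combined; the flag names are the ones listed above, while the model and dataset names are illustrative.

    # Hedged sketch of the trade-off described above: either parallelize eager preprocessing,
    # or defer tokenization entirely. Model/dataset names are illustrative.
    import subprocess

    # Small text dataset: eager preprocessing, sped up with multiple worker processes.
    subprocess.run(['swift', 'sft', '--model_type', 'qwen2_5-7b-instruct',
                    '--dataset', 'alpaca-en', '--preprocess_num_proc', '4'], check=True)

    # Huge or multimodal dataset: defer tokenization to keep memory bounded.
    subprocess.run(['swift', 'sft', '--model_type', 'pixtral-12b',
                    '--dataset', 'coco-en-mini', '--lazy_tokenize', 'true'], check=True)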

docs/source_en/Instruction/Supported-models-datasets.md

Lines changed: 1 addition & 0 deletions

@@ -492,6 +492,7 @@ The table below introcudes all models supported by SWIFT:
 |minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
 |minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|&#x2714;|&#x2714;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
 |minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|&#x2714;|&#x2714;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision, video|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
+|pixtral-12b|[AI-ModelScope/pixtral-12b](https://modelscope.cn/models/AI-ModelScope/pixtral-12b/summary)|^(language_model\|multi_modal_projector)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|pixtral|&#x2718;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.45.0.dev0|vision|[mistral-community/pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b)|
 |mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
 |mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
 |mplug-owl3-7b-chat|[iic/mPLUG-Owl3-7B-240728](https://modelscope.cn/models/iic/mPLUG-Owl3-7B-240728/summary)|^(language_model\|vision2text_model)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|mplug_owl3|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers>=4.36, icecream|vision|[mPLUG/mPLUG-Owl3-7B-240728](https://huggingface.co/mPLUG/mPLUG-Owl3-7B-240728)|
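The pixtral-12b row above requires transformers>=4.45.0.dev0, a pre-release at the time of this commit. A small, hedged way to check the installed version before attempting to load the model:

    # Sketch: verify the transformers requirement listed in the pixtral-12b row above.
    import transformers
    from packaging import version  # packaging ships as a transformers dependency

    required = version.parse('4.45.0.dev0')
    installed = version.parse(transformers.__version__)
    print(f'transformers {installed}; pixtral-12b needs >= {required}: {installed >= required}')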

swift/llm/utils/argument.py

Lines changed: 1 addition & 1 deletion

@@ -919,7 +919,7 @@ def _prepare_target_modules(self, target_modules) -> Union[List[str], str]:
             target_modules.append('DEFAULT')
         if 'DEFAULT' in target_modules:
             target_modules.remove('DEFAULT')
-            default_lora_tm = get_default_lora_target_modules(self.model_type)
+            default_lora_tm = get_default_lora_target_modules(self.model_type) or []
             if isinstance(default_lora_tm, str):
                 return default_lora_tm
             target_modules += default_lora_tm
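The `or []` added above guards against model types that have no default LoRA target modules registered. A tiny standalone illustration of the failure mode it prevents; the lookup function below is a stand-in, not the real registry.

    # Sketch of the guard: extending a list with None raises TypeError, while `None or []` is safe.
    def get_default_lora_target_modules_stub(model_type):
        # Stand-in for the real lookup; returns None for a model type with no registered defaults.
        return {'llama': ['q_proj', 'v_proj']}.get(model_type)

    target_modules = []
    default_lora_tm = get_default_lora_target_modules_stub('some-new-model') or []
    target_modules += default_lora_tm   # extends with [] instead of raising TypeError
    print(target_modules)               # []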

swift/llm/utils/model.py

Lines changed: 32 additions & 1 deletion

@@ -489,6 +489,8 @@ class ModelType:
     mixtral_moe_7b_instruct = 'mixtral-moe-7b-instruct'
     mixtral_moe_7b_aqlm_2bit_1x16 = 'mixtral-moe-7b-aqlm-2bit-1x16'  # aqlm
     mixtral_moe_8x22b_v1 = 'mixtral-moe-8x22b-v1'
+
+    pixtral_12b = 'pixtral-12b'
     # wizardlm
     wizardlm2_7b_awq = 'wizardlm2-7b-awq'
     wizardlm2_8x22b = 'wizardlm2-8x22b'

@@ -1013,6 +1015,26 @@ def _output_device_map_hook(module, input, output):
     return output.to(input[0].device)


+@register_model(
+    ModelType.pixtral_12b,
+    'AI-ModelScope/pixtral-12b',
+    LoRATM.llava,
+    TemplateType.pixtral,
+    # torch_dtype=torch.float16,  # Please do not use bf16.
+    requires=['transformers>=4.45.0.dev0'],
+    placeholder_tokens=['[IMG]'],
+    tags=['multi-modal', 'vision'],
+    hf_model_id='mistral-community/pixtral-12b')
+def get_model_tokenizer_pixtral(model_dir: str, *args, **kwargs):
+    from transformers import AutoProcessor, LlavaForConditionalGeneration
+    processor = AutoProcessor.from_pretrained(model_dir)
+    kwargs['automodel_class'] = LlavaForConditionalGeneration
+    kwargs['tokenizer'] = processor.tokenizer
+    model, tokenizer = get_model_tokenizer_from_repo(model_dir, *args, **kwargs)
+    tokenizer.processor = processor
+    return model, tokenizer
+
+
 @register_model(
     ModelType.cogvlm2_video_13b_chat,
     'ZhipuAI/cogvlm2-video-llama3-chat',
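A minimal, hedged sketch of loading the model type registered in the hunk above through the public ms-swift helpers (assumptions: ms-swift and a new-enough transformers are installed, and the weights can be fetched from ModelScope); `get_model_tokenizer` dispatches to the `get_model_tokenizer_pixtral` function defined above.

    # Hedged sketch: load pixtral-12b through the registration above.
    import torch
    from swift.llm import get_model_tokenizer, get_default_template_type, get_template

    model, tokenizer = get_model_tokenizer('pixtral-12b', torch.float16)  # fp16; bf16 is discouraged above
    template_type = get_default_template_type('pixtral-12b')              # -> 'pixtral'
    template = get_template(template_type, tokenizer)
    print(type(model).__name__)                 # LlavaForConditionalGeneration
    print(type(tokenizer.processor).__name__)   # the processor attached by get_model_tokenizer_pixtral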
@@ -4452,7 +4474,16 @@ def get_model_tokenizer_internvl(model_dir: str,

     model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     use_flash_attn = kwargs.pop('use_flash_attn', False)
-    model_config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
+    if hasattr(model_config.llm_config, 'attn_implementation'):
+        attr = 'attn_implementation'
+    else:
+        attr = '_attn_implementation'
+    if use_flash_attn:
+        setattr(model_config.llm_config, attr, 'flash_attention_2')
+    else:
+        setattr(model_config.llm_config, attr, 'eager')
+    setattr(model_config.llm_config, f'{attr}_internal', None)
+
     model_quant_config = getattr(model_config, 'quantization_config', None)

     use_bnb = False
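The hunk above replaces a hard attribute assignment with a version-tolerant one: some InternVL remote-code configs appear to expose `attn_implementation` on the inner `llm_config`, while configs built on newer transformers releases tend to use `_attn_implementation` with an internal shadow attribute. A small standalone sketch of the pattern, on a stand-in config object:

    # Sketch of the attribute-name fallback above, using a dummy config object.
    class DummyLLMConfig:
        _attn_implementation = 'eager'   # newer-style attribute name

    cfg = DummyLLMConfig()
    use_flash_attn = True
    attr = 'attn_implementation' if hasattr(cfg, 'attn_implementation') else '_attn_implementation'
    setattr(cfg, attr, 'flash_attention_2' if use_flash_attn else 'eager')
    setattr(cfg, f'{attr}_internal', None)   # reset the shadow value, as the diff does
    print(attr, getattr(cfg, attr))          # _attn_implementation flash_attention_2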

swift/llm/utils/preprocess.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ def new_call_func(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
         self.shared_shm_name = shm.name
         buffer = shm.buf
         self.column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=buffer)
+        self.column_state[:] = 0
         dataset = call_func(self, dataset)
         if isinstance(dataset, HfIterableDataset) and dataset.features is None:
             features = next(iter(dataset)).keys()
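The single added line resets the shared flag array to a known all-False state before worker processes read it, instead of relying on whatever the freshly mapped buffer happens to contain. A self-contained sketch of the same idiom (the array length is a stand-in for `len(self.key_mapping)`):

    # Sketch: a bool ndarray backed by SharedMemory, explicitly zeroed before use.
    from multiprocessing import shared_memory
    import numpy as np

    num_columns = 4                                   # stand-in for len(self.key_mapping)
    shm = shared_memory.SharedMemory(create=True, size=num_columns)
    column_state = np.ndarray((num_columns,), dtype=np.bool_, buffer=shm.buf)
    column_state[:] = 0                               # same idiom as the diff: start all-False
    print(column_state.tolist())                      # [False, False, False, False]
    shm.close()
    shm.unlink()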

swift/llm/utils/template.py

Lines changed: 64 additions & 0 deletions

@@ -82,6 +82,7 @@ class TemplateType:

     idefics3 = 'idefics3'
     mistral_nemo = 'mistral-nemo'
+    pixtral = 'pixtral'
     openbuddy = 'openbuddy'
     openbuddy2 = 'openbuddy2'
     internlm = 'internlm'

@@ -1530,6 +1531,69 @@ class Qwen2VLGenerationTemplate(_Qwen2VLTemplateMixin, DefaultGenerationTemplate
 register_template(TemplateType.qwen2_vl_generation, Qwen2VLGenerationTemplate(), lazy_tokenize=True, is_generation=True)


+def _gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
+    # List[Tensor] -> List[Tensor]
+    res = []
+    for b in batch:
+        if b.get(attr_name) is not None:
+            res += b.pop(attr_name)
+    return res
+
+
+class PixtralTemplate(Template):
+
+    def __init__(self):
+        super().__init__(['<s>{{SYSTEM}}'], ['[INST]{{QUERY}}[/INST]'], ['</s>'], ['</s>'], None)
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    example: Dict[str, Any]) -> List[Context]:
+        return ['[IMG]']
+
+    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        inputs, _ = super()._encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        processor = self.tokenizer.processor
+        images = example['images']
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+        idx_list = _findall(input_ids, 10)
+        if idx_list:
+            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
+            inputs['pixel_values'] = image_inputs['pixel_values'][0]
+            image_sizes = image_inputs['image_sizes'][0]
+            added_tokens_len = 0
+            for idx, image_size in zip(idx_list, image_sizes):
+                height, width = image_size
+                num_height_tokens = height // processor.patch_size
+                num_width_tokens = width // processor.patch_size
+                replace_tokens = [processor.image_token * num_width_tokens
+                                  + processor.image_break_token] * (num_height_tokens - 1)
+                replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
+                # Flatten list
+                replace_str = ''.join(replace_tokens)
+                img_tokens: List[int] = self.tokenizer.encode(replace_str, add_special_tokens=False)
+                input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
+                if labels is not None:
+                    labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) \
+                        + labels[idx + added_tokens_len + 1:]
+                added_tokens_len += len(img_tokens) - 1
+            inputs['input_ids'] = input_ids
+            inputs['labels'] = labels
+
+        return inputs, {}
+
+    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
+        pixel_values = _gather_list(batch, 'pixel_values')
+        res = super().data_collator(batch, padding_to)
+        if pixel_values:
+            res['pixel_values'] = pixel_values
+        return res
+
+
+register_template(TemplateType.pixtral, PixtralTemplate(), lazy_tokenize=True)
+
+
 class YiCoderTemplate(ChatmlTemplate):
     system = 'You are a helpful assistant.'
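`PixtralTemplate._encode` above expands each `[IMG]` placeholder (token id 10) into a full grid of image tokens: one row of image tokens per row of patches, rows joined by the processor's break token and the grid closed by its end token, with labels masked to -100 over the inserted span. A standalone sketch of that arithmetic; the literal token strings and image size are illustrative, the real values come from the Pixtral processor attributes used above.

    # Sketch of the placeholder expansion in PixtralTemplate._encode, with illustrative values.
    patch_size = 16
    height, width = 336, 512                      # a resized image, in pixels (illustrative)

    num_height_tokens = height // patch_size      # 21 patch rows
    num_width_tokens = width // patch_size        # 32 patch tokens per row

    image_token, image_break_token, image_end_token = '[IMG]', '[IMG_BREAK]', '[IMG_END]'  # assumed names
    rows = [image_token * num_width_tokens + image_break_token] * (num_height_tokens - 1)
    rows += [image_token * num_width_tokens + image_end_token]
    replace_str = ''.join(rows)                   # replaces one [IMG] placeholder in the prompt

    # 21 * 32 = 672 image tokens, 20 break tokens, and 1 end token replace the single placeholder.
    print(num_height_tokens * num_width_tokens,
          replace_str.count(image_break_token),
          replace_str.count(image_end_token))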
