Commit 8e96f12

Fix qwen2-vl zero2/3 (#2114)
1 parent c3ac7c1

File tree: 2 files changed, +28 −1 lines changed


swift/llm/utils/model.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -3649,6 +3649,7 @@ def _read_from_stream(container: 'av.container.Container', start_offset: float,
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
     tokenizer.processor = processor
     if model is not None:
+        model.model.embed_tokens.register_forward_hook(_clone_hook)
         model.model.embed_tokens.register_forward_hook(_output_device_map_hook)
     return model, tokenizer
```
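Editor's note: `_clone_hook` is referenced here but its body is not part of this diff. As a rough sketch of what a forward hook in this position would do (an assumption, not the verbatim swift implementation): clone the output of `embed_tokens` so that the in-place `inputs_embeds += ... * 0.` trick added in `template.py` below does not mutate a tensor that other hooks or DeepSpeed's ZeRO partitioning may still reference.

```python
import torch
from torch import nn

def _clone_hook(module: nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # Hypothetical body: returning a value from a forward hook replaces the
    # module's output, so downstream code gets a tensor it can safely modify
    # in place (cf. `inputs_embeds += image_embeds.mean() * 0.`).
    return output.clone()

# Registered exactly as in the diff:
# model.model.embed_tokens.register_forward_hook(_clone_hook)
```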

swift/llm/utils/template.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 from transformers import PreTrainedTokenizerBase, StoppingCriteria
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils import strtobool
 
 from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt
 from swift.torchacc_utils import pad_and_split_batch
```
```diff
@@ -179,6 +180,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> b
         return False
 
 
+def is_deepspeed_enabled():
+    return strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False'))
+
+
 class Template:
     """A template class for all supported models.
 
```
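Editor's note: this helper is the scope change behind the commit title. `is_deepspeed_zero3_enabled()` only fires under ZeRO-3, whereas Accelerate exports `ACCELERATE_USE_DEEPSPEED` for any DeepSpeed run, so ZeRO-2 jobs now take the same code path. A minimal check of the helper's behavior (the environment value shown is my reading of what `accelerate launch` sets when a DeepSpeed plugin is active):

```python
import os
from transformers.utils import strtobool

# Without the variable, the helper is falsy (strtobool('False') == 0).
assert strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False')) == 0

# With DeepSpeed active (ZeRO-1/2/3 alike), Accelerate exports the flag.
os.environ['ACCELERATE_USE_DEEPSPEED'] = 'true'
assert strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False')) == 1
```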
```diff
@@ -1504,8 +1509,29 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
 
         inputs['input_ids'] = input_ids
         inputs['labels'] = labels
+        inputs['_data'] = {'plain_text': not images and not videos, 'input_ids': torch.tensor(input_ids)[None]}
         return inputs, {}
 
+    def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+        plain_text = data.pop('plain_text', False)
+        if is_deepspeed_enabled() and plain_text:
+            from PIL import Image
+            images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+            processor = self.tokenizer.processor
+            media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
+            input_ids = data['input_ids']
+            device = input_ids.device
+            pixel_values = media_inputs['pixel_values'].to(device)
+            _model = model.model
+            if not hasattr(_model, 'embed_tokens'):
+                _model = _model.model  # LoRA
+            inputs_embeds = _model.embed_tokens(input_ids)
+            pixel_values = pixel_values.type(model.visual.get_dtype())
+            image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+            inputs_embeds += image_embeds.mean() * 0.
+            return {'inputs_embeds': inputs_embeds[0]}
+        return {}
+
     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
         for media_type in ['image', 'video']:
```
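Editor's note: this hunk is the substance of the fix. Under ZeRO-2/3, gradient reduction and parameter gathering are collective operations, so a rank whose batch happens to be text-only must still run the visual tower, or ranks desynchronize and the job hangs. For `plain_text` samples, `_post_encode` therefore encodes a 32×32 black dummy image and folds it in as `image_embeds.mean() * 0.`: numerically a no-op, but it creates a real autograd edge so `model.visual`'s parameters receive (zero) gradients on every rank. A self-contained sketch of the pattern, with toy module names of my own:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.text = nn.Embedding(10, 4)   # stands in for embed_tokens
        self.visual = nn.Linear(4, 4)     # stands in for the ViT tower

    def forward(self, ids, pixels=None):
        h = self.text(ids)
        if pixels is None:                           # text-only batch
            pixels = torch.zeros(1, 4)               # dummy input (cf. the 32x32 black image)
            h = h + self.visual(pixels).mean() * 0.  # exact no-op value, real autograd edge
        else:
            h = h + self.visual(pixels)
        return h.sum()

model = Toy()
model(torch.tensor([1, 2])).backward()
print(model.visual.weight.grad)  # zeros, not None: every rank joins the reduce
```

Two details worth noting from the diff itself: the `_model = _model.model  # LoRA` hop walks one extra wrapper level when PEFT wraps the base model, and `inputs_embeds[0]` drops the batch dimension that `_encode` added via `torch.tensor(input_ids)[None]`.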
```diff
@@ -2150,7 +2176,7 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
             vit_embeds = model.extract_feature(pixel_values).to(device=device)
             selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
             inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
-        elif is_deepspeed_zero3_enabled():
+        elif is_deepspeed_enabled():
             dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
             vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
             inputs_embeds += vit_embeds.mean() * 0.
```
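Editor's note: this template (InternVL's, judging by the `<IMG_CONTEXT>` token) already had the dummy-forward branch for ZeRO-3; the hunk simply widens its trigger from `is_deepspeed_zero3_enabled()` to the new `is_deepspeed_enabled()`, so ZeRO-2 runs get the same protection as the Qwen2-VL path above.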
