18 | 18 | from transformers import PreTrainedTokenizerBase, StoppingCriteria
19 | 19 | from transformers.dynamic_module_utils import get_class_from_dynamic_module
20 | 20 | from transformers.integrations import is_deepspeed_zero3_enabled
| 21 | +from transformers.utils import strtobool |
21 | 22 |
22 | 23 | from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt
23 | 24 | from swift.torchacc_utils import pad_and_split_batch
@@ -179,6 +180,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> b
179 | 180 | return False
180 | 181 |
181 | 182 |
| 183 | +def is_deepspeed_enabled(): |
| 184 | + return strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False')) |
| 185 | + |
| 186 | + |
182 | 187 | class Template:
183 | 188 | """A template class for all supported models.
184 | 189 |
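The hunk above introduces `is_deepspeed_enabled()`, which keys off the `ACCELERATE_USE_DEEPSPEED` environment variable (exported by `accelerate` when a DeepSpeed plugin is configured) rather than transformers' ZeRO-3-specific check, so it is truthy for any ZeRO stage. A minimal sketch of the expected behaviour, assuming the variable is set the way an `accelerate` DeepSpeed launch normally sets it:

```python
# Sketch only: replays the env-var/strtobool logic of the new helper.
import os
from transformers.utils import strtobool

def is_deepspeed_enabled():
    return strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False'))

os.environ['ACCELERATE_USE_DEEPSPEED'] = 'true'   # what accelerate typically exports
assert is_deepspeed_enabled() == 1

os.environ.pop('ACCELERATE_USE_DEEPSPEED')        # no DeepSpeed: falls back to 'False'
assert is_deepspeed_enabled() == 0
```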
@@ -1504,8 +1509,29 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
1504 | 1509 |
1505 | 1510 | inputs['input_ids'] = input_ids
1506 | 1511 | inputs['labels'] = labels
| 1512 | + inputs['_data'] = {'plain_text': not images and not videos, 'input_ids': torch.tensor(input_ids)[None]} |
1507 | 1513 | return inputs, {}
1508 | 1514 |
| 1515 | + def _post_encode(self, model, data: Any) -> Dict[str, Any]: |
| 1516 | + plain_text = data.pop('plain_text', False) |
| 1517 | + if is_deepspeed_enabled() and plain_text: |
| 1518 | + from PIL import Image |
| 1519 | + images = [Image.new('RGB', (32, 32), (0, 0, 0))] |
| 1520 | + processor = self.tokenizer.processor |
| 1521 | + media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt') |
| 1522 | + input_ids = data['input_ids'] |
| 1523 | + device = input_ids.device |
| 1524 | + pixel_values = media_inputs['pixel_values'].to(device) |
| 1525 | + _model = model.model |
| 1526 | + if not hasattr(_model, 'embed_tokens'): |
| 1527 | + _model = _model.model # LoRA |
| 1528 | + inputs_embeds = _model.embed_tokens(input_ids) |
| 1529 | + pixel_values = pixel_values.type(model.visual.get_dtype()) |
| 1530 | + image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) |
| 1531 | + inputs_embeds += image_embeds.mean() * 0. |
| 1532 | + return {'inputs_embeds': inputs_embeds[0]} |
| 1533 | + return {} |
| 1534 | + |
1509 | 1535 | def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
1510 | 1536 | res = super().data_collator(batch, padding_to)
1511 | 1537 | for media_type in ['image', 'video']:
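For context on the `_encode`/`_post_encode` changes above: `_encode` now stashes a `plain_text` flag and the pre-batched `input_ids` in `inputs['_data']`, and `_post_encode` consumes them. When DeepSpeed is enabled but the sample contains no images or videos, the vision tower would otherwise be skipped for that sample, so a throwaway 32x32 black image is run through `model.visual` and folded in as `image_embeds.mean() * 0.`: the visual parameters stay in the autograd graph (the intent appears to be that every rank touches the same set of parameters under DeepSpeed's collective handling) while the text embeddings are numerically unchanged. A standalone sketch of that pattern with toy modules rather than the template's real model:

```python
# Toy illustration of the zero-weighted dummy forward used in _post_encode.
import torch
import torch.nn as nn

embed = nn.Embedding(10, 8)   # stand-in for the language model's embed_tokens
visual = nn.Linear(4, 8)      # stand-in for the vision tower

input_ids = torch.tensor([[1, 2, 3]])
dummy_pixels = torch.zeros(1, 4)   # stand-in for the 32x32 black dummy image

inputs_embeds = embed(input_ids)
image_embeds = visual(dummy_pixels)
inputs_embeds = inputs_embeds + image_embeds.mean() * 0.  # zero contribution, but keeps `visual` in the graph

inputs_embeds.sum().backward()
assert visual.weight.grad is not None  # the vision tower still receives (zero) gradients
```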
@@ -2150,7 +2176,7 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
2150 | 2176 | vit_embeds = model.extract_feature(pixel_values).to(device=device)
2151 | 2177 | selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
2152 | 2178 | inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
2153 | | - elif is_deepspeed_zero3_enabled(): |
| 2179 | + elif is_deepspeed_enabled(): |
2154 | 2180 | dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
2155 | 2181 | vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
2156 | 2182 | inputs_embeds += vit_embeds.mean() * 0.
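This last hunk only swaps the guard in the other `_post_encode`: the dummy `extract_feature` pass there used to run only under ZeRO-3 and now runs whenever DeepSpeed is active. A hedged sketch of the difference between the two checks; the environment value is what `accelerate` is assumed to export, and it should be run outside a real ZeRO-3 context:

```python
# Outside a ZeRO-3 HfDeepSpeedConfig, the old check is False even though DeepSpeed
# (e.g. ZeRO-1/2) is in use, while the env-var based check is truthy.
import os
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool

os.environ['ACCELERATE_USE_DEEPSPEED'] = 'true'  # assumed to be set for any DeepSpeed stage

print(is_deepspeed_zero3_enabled())                       # False: no ZeRO-3 config registered in this process
print(strtobool(os.environ['ACCELERATE_USE_DEEPSPEED']))  # 1, i.e. what is_deepspeed_enabled() returns
```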