
Commit 92a5011

compat transformer_engine update (#4317)
1 parent a8217e2 commit 92a5011

File tree

swift/llm/model/model/qwen.py
swift/megatron/init.py
swift/trainers/rlhf_trainer/rlhf_mixin.py

3 files changed: +23 additions, −9 deletions


swift/llm/model/model/qwen.py

Lines changed: 7 additions & 3 deletions
@@ -582,9 +582,13 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
     from transformers import Qwen2VLForConditionalGeneration
     kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
     model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
-    if model is not None and hasattr(model.model, 'embed_tokens'):
-        patch_output_clone(model.model.embed_tokens)
-        patch_output_to_input_device(model.model.embed_tokens)
+    if model is not None:
+        if hasattr(model.model, 'embed_tokens'):
+            embed_tokens = model.model.embed_tokens
+        else:
+            embed_tokens = model.model.language_model.embed_tokens
+        patch_output_clone(embed_tokens)
+        patch_output_to_input_device(embed_tokens)
 
     from qwen_vl_utils import vision_process
     patch_qwen_vl_utils(vision_process)
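
Note on the qwen.py hunk: the old guard skipped patching entirely when model.model had no embed_tokens attribute; the new branch also covers layouts where the text backbone is nested under model.model.language_model (as in newer transformers releases). A minimal standalone sketch of the lookup this diff introduces, assuming a Qwen2-VL style model; the helper name resolve_embed_tokens is illustrative and not part of the repository:

def resolve_embed_tokens(model):
    # Illustrative helper (not in ms-swift): locate the token embedding module
    # whether or not the text model is nested under `language_model`.
    inner = model.model
    if hasattr(inner, 'embed_tokens'):
        return inner.embed_tokens  # older flat layout
    return inner.language_model.embed_tokens  # newer nested layout

Both patch_output_clone and patch_output_to_input_device then receive the same module, whichever layout is present.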

swift/megatron/init.py

Lines changed: 13 additions & 4 deletions
@@ -10,13 +10,22 @@
 
 
 def _patch_transformer_engine():
+    import transformer_engine
     try:
-        from transformer_engine.pytorch.attention import FusedRoPEFunc
+        from transformer_engine.pytorch.attention import apply_rotary_pos_emb
     except ImportError:
         try:
-            import transformer_engine
-            transformer_engine.pytorch.attention.FusedRoPEFunc = (
-                transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
+            transformer_engine.pytorch.attention.apply_rotary_pos_emb = (
+                transformer_engine.pytorch.attention.rope.apply_rotary_pos_emb)
+            logger.info('Patch apply_rotary_pos_emb successfully applied.')
+        except (ImportError, AttributeError):
+            pass
+    try:
+        from transformer_engine.pytorch.attention import _SplitAlongDim
+    except ImportError:
+        try:
+            transformer_engine.pytorch.attention._SplitAlongDim = (transformer_engine.pytorch.utils.SplitAlongDim)
+            logger.info('Patch _SplitAlongDim successfully applied.')
         except (ImportError, AttributeError):
             pass
 
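
Both hunks above follow the same compatibility-shim pattern: try the import location that older Transformer Engine releases expose, and if the symbol has moved in the installed version, alias it from its new module back onto the old path so existing imports (e.g. from Megatron) keep working; if neither location exists, fail silently. A generalised sketch of that pattern, purely illustrative and not part of the repository (ensure_symbol is a hypothetical helper):

import importlib

def ensure_symbol(legacy_module: str, name: str, new_module: str) -> None:
    # Re-export `name` onto `legacy_module` when it has moved to `new_module`;
    # swallow failures so environments without either location are unaffected.
    legacy = importlib.import_module(legacy_module)
    if hasattr(legacy, name):
        return  # the legacy path still works, nothing to patch
    try:
        setattr(legacy, name, getattr(importlib.import_module(new_module), name))
    except (ImportError, AttributeError):
        pass

Usage mirroring the first hunk (module paths taken from the diff):

ensure_symbol('transformer_engine.pytorch.attention', 'apply_rotary_pos_emb',
              'transformer_engine.pytorch.attention.rope')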

swift/trainers/rlhf_trainer/rlhf_mixin.py

Lines changed: 3 additions & 2 deletions
@@ -52,8 +52,9 @@ def __init__(self,
 
     def get_train_dataloader(self):
         train_dataloader = super().get_train_dataloader()
-        base_dataloader = train_dataloader if isinstance(train_dataloader,
-                                                         DataLoader) else train_dataloader.base_dataloader
+        base_dataloader = train_dataloader.base_dataloader if hasattr(
+            train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader,
+                                                                DataLoader) else train_dataloader
         if base_dataloader.worker_init_fn is not None and not isinstance(
                 base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
                     base_dataloader.worker_init_fn).parameters:
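
The rewritten condition prefers the wrapped torch DataLoader when the trainer's dataloader exposes one through a base_dataloader attribute that is itself a DataLoader, and otherwise falls back to the dataloader object, so both plain and wrapped dataloaders reach the worker_init_fn check below. A minimal sketch of the same resolution as a standalone function, assuming only torch; unwrap_dataloader is an illustrative name, not the repository's API:

from torch.utils.data import DataLoader

def unwrap_dataloader(train_dataloader):
    # Prefer an inner torch DataLoader exposed via `base_dataloader`
    # (as wrapping dataloaders may do); otherwise use the object itself.
    inner = getattr(train_dataloader, 'base_dataloader', None)
    return inner if isinstance(inner, DataLoader) else train_dataloader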
