diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py
index f02f38aaf9..c717346aa9 100644
--- a/swift/llm/model/model/qwen.py
+++ b/swift/llm/model/model/qwen.py
@@ -582,9 +582,13 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
     from transformers import Qwen2VLForConditionalGeneration
     kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
     model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
-    if model is not None and hasattr(model.model, 'embed_tokens'):
-        patch_output_clone(model.model.embed_tokens)
-        patch_output_to_input_device(model.model.embed_tokens)
+    if model is not None:
+        if hasattr(model.model, 'embed_tokens'):
+            embed_tokens = model.model.embed_tokens
+        else:
+            embed_tokens = model.model.language_model.embed_tokens
+        patch_output_clone(embed_tokens)
+        patch_output_to_input_device(embed_tokens)
     from qwen_vl_utils import vision_process
     patch_qwen_vl_utils(vision_process)

diff --git a/swift/megatron/init.py b/swift/megatron/init.py
index 72380c414a..ef1486d5dd 100644
--- a/swift/megatron/init.py
+++ b/swift/megatron/init.py
@@ -10,13 +10,22 @@


 def _patch_transformer_engine():
+    import transformer_engine
     try:
-        from transformer_engine.pytorch.attention import FusedRoPEFunc
+        from transformer_engine.pytorch.attention import apply_rotary_pos_emb
     except ImportError:
         try:
-            import transformer_engine
-            transformer_engine.pytorch.attention.FusedRoPEFunc = (
-                transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
+            transformer_engine.pytorch.attention.apply_rotary_pos_emb = (
+                transformer_engine.pytorch.attention.rope.apply_rotary_pos_emb)
+            logger.info('Patch apply_rotary_pos_emb successfully applied.')
+        except (ImportError, AttributeError):
+            pass
+    try:
+        from transformer_engine.pytorch.attention import _SplitAlongDim
+    except ImportError:
+        try:
+            transformer_engine.pytorch.attention._SplitAlongDim = (transformer_engine.pytorch.utils.SplitAlongDim)
+            logger.info('Patch _SplitAlongDim successfully applied.')
         except (ImportError, AttributeError):
             pass

diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py
index 103c149cc2..8380553600 100644
--- a/swift/trainers/rlhf_trainer/rlhf_mixin.py
+++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py
@@ -52,8 +52,9 @@ def __init__(self,

     def get_train_dataloader(self):
         train_dataloader = super().get_train_dataloader()
-        base_dataloader = train_dataloader if isinstance(train_dataloader,
-                                                         DataLoader) else train_dataloader.base_dataloader
+        base_dataloader = train_dataloader.base_dataloader if hasattr(
+            train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader,
+                                                               DataLoader) else train_dataloader
         if base_dataloader.worker_init_fn is not None and not isinstance(
             base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
                 base_dataloader.worker_init_fn).parameters: