Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions swift/llm/model/model/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,13 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
from transformers import Qwen2VLForConditionalGeneration
kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
if model is not None and hasattr(model.model, 'embed_tokens'):
patch_output_clone(model.model.embed_tokens)
patch_output_to_input_device(model.model.embed_tokens)
if model is not None:
if hasattr(model.model, 'embed_tokens'):
embed_tokens = model.model.embed_tokens
else:
embed_tokens = model.model.language_model.embed_tokens
patch_output_clone(embed_tokens)
patch_output_to_input_device(embed_tokens)

from qwen_vl_utils import vision_process
patch_qwen_vl_utils(vision_process)
Expand Down
17 changes: 13 additions & 4 deletions swift/megatron/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@


def _patch_transformer_engine():
import transformer_engine
try:
from transformer_engine.pytorch.attention import FusedRoPEFunc
from transformer_engine.pytorch.attention import apply_rotary_pos_emb
except ImportError:
try:
import transformer_engine
transformer_engine.pytorch.attention.FusedRoPEFunc = (
transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
transformer_engine.pytorch.attention.apply_rotary_pos_emb = (
transformer_engine.pytorch.attention.rope.apply_rotary_pos_emb)
logger.info('Patch apply_rotary_pos_emb successfully applied.')
except (ImportError, AttributeError):
pass
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
except ImportError:
try:
transformer_engine.pytorch.attention._SplitAlongDim = (transformer_engine.pytorch.utils.SplitAlongDim)
logger.info('Patch _SplitAlongDim successfully applied.')
except (ImportError, AttributeError):
pass

Expand Down
5 changes: 3 additions & 2 deletions swift/trainers/rlhf_trainer/rlhf_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def __init__(self,

def get_train_dataloader(self):
train_dataloader = super().get_train_dataloader()
base_dataloader = train_dataloader if isinstance(train_dataloader,
DataLoader) else train_dataloader.base_dataloader
base_dataloader = train_dataloader.base_dataloader if hasattr(
train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader,
DataLoader) else train_dataloader
if base_dataloader.worker_init_fn is not None and not isinstance(
base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
base_dataloader.worker_init_fn).parameters:
Expand Down
Loading