3 files changed: +23 −9 lines

@@ -582,9 +582,13 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
     from transformers import Qwen2VLForConditionalGeneration
     kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
     model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
-    if model is not None and hasattr(model.model, 'embed_tokens'):
-        patch_output_clone(model.model.embed_tokens)
-        patch_output_to_input_device(model.model.embed_tokens)
+    if model is not None:
+        if hasattr(model.model, 'embed_tokens'):
+            embed_tokens = model.model.embed_tokens
+        else:
+            embed_tokens = model.model.language_model.embed_tokens
+        patch_output_clone(embed_tokens)
+        patch_output_to_input_device(embed_tokens)
 
     from qwen_vl_utils import vision_process
     patch_qwen_vl_utils(vision_process)
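The change above appears to account for newer transformers layouts of Qwen2VLForConditionalGeneration, where the text backbone (and thus `embed_tokens`) sits one level deeper at `model.model.language_model` instead of directly on `model.model`. A minimal sketch of the same lookup as a standalone helper; the function name `_resolve_embed_tokens` is an illustrative assumption, not part of the PR:

```python
# Illustrative sketch (not part of the PR): resolve the text embedding module for
# either Qwen2-VL layout, then attach the same hooks the diff applies.
def _resolve_embed_tokens(model):
    inner = model.model
    if hasattr(inner, 'embed_tokens'):
        # older layout: model.model is the text model itself
        return inner.embed_tokens
    # newer layout: text model nested under language_model
    return inner.language_model.embed_tokens

# usage, with patch_output_clone / patch_output_to_input_device as in the diff:
# embed_tokens = _resolve_embed_tokens(model)
# patch_output_clone(embed_tokens)
# patch_output_to_input_device(embed_tokens)
```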
@@ -10,13 +10,22 @@
 def _patch_transformer_engine():
+    import transformer_engine
     try:
-        from transformer_engine.pytorch.attention import FusedRoPEFunc
+        from transformer_engine.pytorch.attention import apply_rotary_pos_emb
     except ImportError:
         try:
-            import transformer_engine
-            transformer_engine.pytorch.attention.FusedRoPEFunc = (
-                transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
+            transformer_engine.pytorch.attention.apply_rotary_pos_emb = (
+                transformer_engine.pytorch.attention.rope.apply_rotary_pos_emb)
+            logger.info('Patch apply_rotary_pos_emb successfully applied.')
+        except (ImportError, AttributeError):
+            pass
+    try:
+        from transformer_engine.pytorch.attention import _SplitAlongDim
+    except ImportError:
+        try:
+            transformer_engine.pytorch.attention._SplitAlongDim = (transformer_engine.pytorch.utils.SplitAlongDim)
+            logger.info('Patch _SplitAlongDim successfully applied.')
         except (ImportError, AttributeError):
             pass
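This hunk re-exports symbols that newer Transformer Engine releases relocated (`apply_rotary_pos_emb` into `attention.rope`, `SplitAlongDim` into `utils`) back onto `transformer_engine.pytorch.attention`, so downstream code importing from the old path keeps resolving. A generic sketch of that compatibility-shim pattern; the helper name `_reexport` and its signature are assumptions for illustration only:

```python
# Illustrative sketch of the re-export pattern used above: if a symbol is gone from
# its old module, copy it there from its new home; otherwise do nothing.
import importlib
import logging

logger = logging.getLogger(__name__)


def _reexport(old_module, name, new_module, new_name=None):
    old = importlib.import_module(old_module)
    if hasattr(old, name):
        return  # still available at the old path, nothing to patch
    try:
        new = importlib.import_module(new_module)
        setattr(old, name, getattr(new, new_name or name))
        logger.info('Patch %s successfully applied.', name)
    except (ImportError, AttributeError):
        pass  # new location absent too; leave the module untouched

# e.g. the two shims from the diff:
# _reexport('transformer_engine.pytorch.attention', 'apply_rotary_pos_emb',
#           'transformer_engine.pytorch.attention.rope')
# _reexport('transformer_engine.pytorch.attention', '_SplitAlongDim',
#           'transformer_engine.pytorch.utils', 'SplitAlongDim')
```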
@@ -52,8 +52,9 @@ def __init__(self,
 
     def get_train_dataloader(self):
         train_dataloader = super().get_train_dataloader()
-        base_dataloader = train_dataloader if isinstance(train_dataloader,
-                                                         DataLoader) else train_dataloader.base_dataloader
+        base_dataloader = train_dataloader.base_dataloader if hasattr(
+            train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader,
+                                                                DataLoader) else train_dataloader
         if base_dataloader.worker_init_fn is not None and not isinstance(
                 base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature(
                     base_dataloader.worker_init_fn).parameters:
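The rewritten expression only unwraps when the returned loader actually exposes a `base_dataloader` that is a plain `DataLoader` (as wrapped loaders from accelerate may), and otherwise falls back to the object itself, avoiding an AttributeError when `super().get_train_dataloader()` already returns a bare loader. A small sketch isolating that logic; the helper name `_unwrap_dataloader` is hypothetical:

```python
# Illustrative sketch mirroring the expression in the diff: prefer the wrapped
# loader's base_dataloader only when it exists and really is a DataLoader.
from torch.utils.data import DataLoader


def _unwrap_dataloader(train_dataloader):
    base = getattr(train_dataloader, 'base_dataloader', None)
    if isinstance(base, DataLoader):
        return base
    return train_dataloader
```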