
Commit 5479249

[megatron] Add extra args and provider support for easily customize megatron (#4240)
1 parent 67271bb commit 5479249

File tree (5 files changed, +37 −4 lines changed):

swift/megatron/argument/megatron_args.py
swift/megatron/model/register.py
swift/megatron/train/sft.py
swift/megatron/utils/convert.py
tests/megatron/test_model.py

swift/megatron/argument/megatron_args.py

Lines changed: 24 additions & 0 deletions
@@ -4,6 +4,7 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

+import json
 import megatron.core
 import torch
 from packaging import version
@@ -167,6 +168,9 @@ class MegatronArguments(ExtraMegatronArguments):
     num_workers: int = 4
     no_create_attention_mask_in_dataloader: bool = True

+    # extra_args for megatron
+    extra_megatron_kwargs: Optional[str] = None
+
     def _set_default(self):
         if self.num_query_groups is None:
             self.num_query_groups = 1
@@ -231,6 +235,17 @@ def __post_init__(self):

         self.tensorboard_dir = to_abspath(self.tensorboard_dir)

+        try:
+            if self.extra_megatron_kwargs is None:
+                self.extra_megatron_kwargs = {}
+            elif isinstance(self.extra_megatron_kwargs, str):
+                self.extra_megatron_kwargs = json.loads(self.extra_megatron_kwargs)
+            elif isinstance(self.extra_megatron_kwargs, dict):
+                # For loading from config file
+                self.extra_megatron_kwargs = self.extra_megatron_kwargs
+        except json.JSONDecodeError:
+            raise ValueError('extra_megatron_kwargs should be a valid json string')
+
     def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
         new_args = []
         args_dict = asdict(self)
@@ -241,6 +256,15 @@ def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
             if k not in MegatronArguments.__annotations__:
                 extra_args[k] = value
                 continue
+            if k == 'extra_megatron_kwargs':
+                if isinstance(value, str):
+                    value = json.loads(value)
+                if not isinstance(value, dict):
+                    raise ValueError(f'extra_megatron_kwargs should be a dict, but got {type(value)}')
+                for sub_key, sub_value in value.items():
+                    new_args.append(f"--{sub_key.replace('_', '-')}")
+                    new_args.append(str(sub_value))
+                continue
             if value is None or value is False:
                 continue
             new_args.append(f"--{k.replace('_', '-')}")

swift/megatron/model/register.py

Lines changed: 3 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional

@@ -20,6 +21,8 @@ class MegatronModelMeta:
     convert_mcore2hf: Callable[[nn.Module, nn.Module], None]
     convert_hf2mcore: Callable[[nn.Module, nn.Module], None]

+    extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None
+

 def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
     megatron_model_type = megatron_model_meta.megatron_model_type
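The new `extra_args_provider` field follows the usual Megatron hook shape: it receives an `ArgumentParser`, may add options, and returns the parser. A hedged sketch of what a custom provider could look like (the group title and `--my-custom-flag` are invented for illustration; only the `parser -> parser` signature comes from this commit):

```python
from argparse import ArgumentParser


def my_extra_args_provider(parser: ArgumentParser) -> ArgumentParser:
    """Hypothetical provider: register flags the base parser does not define."""
    group = parser.add_argument_group(title='custom model args')
    group.add_argument('--my-custom-flag', type=int, default=0,
                       help='Example option added by a model plugin (made up).')
    return parser
```

Given the dataclass field above, such a provider would presumably be attached via `MegatronModelMeta(..., extra_args_provider=my_extra_args_provider)` when the model is registered with `register_megatron_model`.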

swift/megatron/train/sft.py

Lines changed: 2 additions & 0 deletions
@@ -47,11 +47,13 @@ def run(self):
         logger.info(f'The logging file will be saved in: {logging_path}')
         try:
             with patch_megatron_data_collator(data_collator):
+                extra_args_provider = args.megatron_model_meta.extra_args_provider
                 pretrain(
                     datasets_provider,
                     args.megatron_model_meta.model_provider,
                     ModelType.encoder_or_decoder,
                     forward_step,
+                    extra_args_provider=extra_args_provider,
                     args_defaults=args.extra_args)
         finally:
             # Visualization
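Here both hooks reach Megatron's `pretrain` together: `extra_args_provider` registers additional flags with the argument parser before parsing, while `args_defaults` supplies default values for flags the parser already knows. A generic argparse illustration of that pattern, not Megatron's actual internals (all names below are invented):

```python
from argparse import ArgumentParser


def extra_args_provider(parser: ArgumentParser) -> ArgumentParser:
    # Register a flag the base parser does not define (name is made up).
    parser.add_argument('--my-custom-flag', type=int, default=0)
    return parser


def build_args(args_defaults: dict):
    parser = ArgumentParser()
    parser.add_argument('--num-workers', type=int, default=2)
    parser = extra_args_provider(parser)  # the provider extends the parser
    parser.set_defaults(**args_defaults)  # defaults are overridden, CLI values still win
    return parser.parse_args([])


args = build_args({'num_workers': 4})
print(args.num_workers, args.my_custom_flag)  # num_workers is 4, my_custom_flag is 0
```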

swift/megatron/utils/convert.py

Lines changed: 4 additions & 2 deletions
@@ -74,7 +74,8 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
     patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
-    initialize_megatron(args_defaults=extra_args)
+    extra_args_provider = megatron_model_meta.extra_args_provider
+    initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args)

     mg_model = megatron_model_meta.model_provider()
     logger.info('Megatron model created successfully.')
@@ -101,7 +102,8 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
     patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
-    initialize_megatron(args_defaults=extra_args)
+    extra_args_provider = megatron_model_meta.extra_args_provider
+    initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args)

     mg_model = megatron_model_meta.model_provider()
     load_checkpoint([mg_model], None, None, strict=True)

tests/megatron/test_model.py

Lines changed: 4 additions & 2 deletions
@@ -11,10 +11,12 @@ def get_mg_model_tokenizer(model_id):
     megatron_model_meta = get_megatron_model_meta(processor.model_meta.model_type)
     model_info = processor.model_info
     kwargs = megatron_model_meta.convert_hf_config(model_info.config)
-    megatron_args = MegatronArguments(**kwargs, seq_length=1, use_cpu_initialization=True, no_initialization=True)
+    megatron_args = MegatronArguments(
+        **kwargs, seq_length=1, use_cpu_initialization=True, no_initialization=True, torch_dtype=torch.float32)
+    extra_args_provider = megatron_model_meta.extra_args_provider
     patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
-    initialize_megatron(args_defaults=extra_args)
+    initialize_megatron(args_defaults=extra_args, extra_args_provider=extra_args_provider)
     mg_model = megatron_model_meta.model_provider()
     megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
     return hf_model, mg_model, processor
