4
4
from dataclasses import asdict , dataclass , field
5
5
from typing import Any , Dict , List , Literal , Optional , Tuple , Union
6
6
7
+ import json
7
8
import megatron .core
8
9
import torch
9
10
from packaging import version
@@ -167,6 +168,9 @@ class MegatronArguments(ExtraMegatronArguments):
167
168
num_workers : int = 4
168
169
no_create_attention_mask_in_dataloader : bool = True
169
170
171
# Extra Megatron arguments that have no dedicated field on this dataclass.
# May be supplied as a JSON string or as a dict (e.g. when loaded from a
# config file); __post_init__ normalizes it to a dict.
extra_megatron_kwargs: Optional[Union[dict, str]] = None
173
+
170
174
def _set_default (self ):
171
175
if self .num_query_groups is None :
172
176
self .num_query_groups = 1
@@ -231,6 +235,17 @@ def __post_init__(self):
231
235
232
236
self .tensorboard_dir = to_abspath (self .tensorboard_dir )
233
237
238
# Normalize `extra_megatron_kwargs` to a dict: None becomes an empty dict,
# a string is parsed as JSON, and a dict (e.g. loaded from a config file)
# is kept as-is.
extra_kwargs = self.extra_megatron_kwargs
if extra_kwargs is None:
    extra_kwargs = {}
elif isinstance(extra_kwargs, str):
    # Keep the try body minimal: only json.loads can raise JSONDecodeError.
    try:
        extra_kwargs = json.loads(extra_kwargs)
    except json.JSONDecodeError as e:
        raise ValueError('extra_megatron_kwargs should be a valid json string') from e
# Valid JSON that is not an object (list, number, ...) and any other input
# type are rejected here instead of failing later in _args_to_argv.
if not isinstance(extra_kwargs, dict):
    raise ValueError(f'extra_megatron_kwargs should be a dict, but got {type(extra_kwargs)}')
self.extra_megatron_kwargs = extra_kwargs
248
+
234
249
def _args_to_argv (self ) -> Tuple [List [Any ], Dict [str , Any ]]:
235
250
new_args = []
236
251
args_dict = asdict (self )
@@ -241,6 +256,15 @@ def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
241
256
if k not in MegatronArguments .__annotations__ :
242
257
extra_args [k ] = value
243
258
continue
259
+ if k == 'extra_megatron_kwargs' :
260
+ if isinstance (value , str ):
261
+ value = json .loads (value )
262
+ if not isinstance (value , dict ):
263
+ raise ValueError (f'extra_megatron_kwargs should be a dict, but got { type (value )} ' )
264
+ for sub_key , sub_value in value .items ():
265
+ new_args .append (f"--{ sub_key .replace ('_' , '-' )} " )
266
+ new_args .append (str (sub_value ))
267
+ continue
244
268
if value is None or value is False :
245
269
continue
246
270
new_args .append (f"--{ k .replace ('_' , '-' )} " )
0 commit comments