diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 1cf60302d..8b301cf8b 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = fdac107
+    Default = f70c54d

     current git hash of repository

diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 807faac75..9c86d98d3 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -37,7 +37,7 @@
     ParallelLinear,
 )
 from megatron.model.gmlp import GMLPBlock
-from megatron.model.mamba import MambaResidualLayerPipe
+from megatron.model.mamba import ParallelMambaResidualLayerPipe
 from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding

 # Pipeline parallelism
@@ -138,7 +138,7 @@ def __init__(
             checkpointable_layers=[
                 "GMLPBlock",
                 "ParallelTransformerLayerPipe",
-                "MambaResidualLayerPipe",
+                "ParallelMambaResidualLayerPipe",
             ],
         )

@@ -174,7 +174,7 @@ def insert_layers(
             checkpointable_layers=[
                 "GMLPBlock",
                 "ParallelTransformerLayerPipe",
-                "MambaResidualLayerPipe",
+                "ParallelMambaResidualLayerPipe",
             ],
         )

@@ -254,7 +254,7 @@ def init_specs(self):
         elif layer_type in ["mamba"]:
             self.specs.append(
                 LayerSpec(
-                    MambaResidualLayerPipe,
+                    ParallelMambaResidualLayerPipe,
                     neox_args=self.neox_args,
                     init_method=self.init_method,
                     output_layer_init_method=self.output_layer_init_method,
diff --git a/megatron/model/mamba/__init__.py b/megatron/model/mamba/__init__.py
index b55311091..a024707ad 100644
--- a/megatron/model/mamba/__init__.py
+++ b/megatron/model/mamba/__init__.py
@@ -1 +1,4 @@
-from .mamba import MambaResidualLayer, MambaResidualLayerPipe
+from .mamba import (
+    ParallelMambaResidualLayer,
+    ParallelMambaResidualLayerPipe,
+)
diff --git a/megatron/model/mamba/mamba.py b/megatron/model/mamba/mamba.py
index eedd19f40..d5d6b336f 100644
--- a/megatron/model/mamba/mamba.py
+++ b/megatron/model/mamba/mamba.py
@@ -19,10 +19,10 @@
     pass

 from megatron.model.norms import get_norm
+from megatron import mpu

-
-# Mamba layer, without parallelism.
-class MambaBlock(nn.Module):
+# Mamba sublayer, with tensor parallelism
+class ParallelMambaBlock(nn.Module):
     def __init__(
         self,
         neox_args,
@@ -58,22 +58,33 @@ def __init__(
         self.dt_min, self.dt_max, self.dt_init_floor = 0.001, 0.1, 1e-4
         assert self.dt_init in ["constant", "random"]

+        # TP-specific setup
+        world_size = mpu.get_model_parallel_world_size()
+        self.d_inner_per_rank = mpu.divide(self.d_inner, world_size)
+
+        if neox_args.mamba_inner_func_fusion and world_size > 1:
+            # as with gpt-j residual, we must manually reduce output from final proj
+            # across TP ranks, since it is not done by fused mamba_inner_fn .
+            self.reduce = mpu.mappings.reduce_from_model_parallel_region
+
         # up-projection.
-        self.in_proj = nn.Linear(
-            self.d_model,
-            self.d_inner * 2,
+        self.in_proj = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=self.d_model,
+            output_size=self.d_inner * 2,
+            gather_output=False,
+            init_method=init_method,
+            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
             bias=neox_args.mamba_use_bias_in_linears,
-            **factory_kwargs,
         )
-        init_method(self.in_proj.weight)

-        # convolution.
+        # convolution (parallelized across d_inner)
         self.conv1d = nn.Conv1d(
-            in_channels=self.d_inner,
-            out_channels=self.d_inner,
+            in_channels=self.d_inner_per_rank,
+            out_channels=self.d_inner_per_rank,
             bias=neox_args.mamba_use_bias_in_conv,
             kernel_size=self.d_conv,
-            groups=self.d_inner,
+            groups=self.d_inner_per_rank,
             padding=self.d_conv - 1,
             **factory_kwargs,
         )
@@ -81,27 +92,30 @@ def __init__(
         # Uncertain why
         self.conv1d.to(self.precision)

-        self.act_fn = F.silu  # we do not allow for
+        self.act_fn = F.silu  # we do not allow for other activation fns

         # x_proj corresponds to s_B(x), s_C(x), s_Delta(x)
         # in https://arxiv.org/pdf/2312.00752.pdf Algorithm 2
         # (computes data-dependent B, C, Delta/dt)
-        self.x_proj = nn.Linear(
-            self.d_inner,
-            self.dt_rank + self.d_state * 2,
+        self.x_proj = mpu.RowParallelLinear(
+            neox_args=neox_args,
+            input_size=self.d_inner,
+            output_size=self.dt_rank + self.d_state * 2,
+            input_is_parallel=True,
+            init_method=init_method,
+            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
+            parallel_output=True,
             bias=neox_args.mamba_use_bias_in_linears,
-            **factory_kwargs,
         )
-        init_method(self.x_proj.weight)

         # up-project dt / Delta from dt_rank to d_inner
-        # dt_proj 's bias is a special case and I believe we should keep it turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
+        # dt_proj 's bias is a special case and should be kept always turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
         # defines Delta as Delta = Tau_{Delta}(Parameter + s_{Delta}(x)) where s_{Delta}(x) = Broadcast_{D}(Linear_{1}(x))
         # or as they further explain in section 3.6 can be also s_{Delta}(x) = Linear_{D}(Linear_{R}(x)) where Linear_R
         # is the delta portion of x_proj and Linear_D is the dt_proj weight. Then, the Parameter term from Alg. 2 can
         # be viewed as the bias term in dt_proj, with a special initialization from https://arxiv.org/abs/2206.12037
         self.dt_proj = nn.Linear(
-            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
+            self.dt_rank, self.d_inner_per_rank, bias=True, **factory_kwargs
         )

         # special init for dt_proj
@@ -115,7 +129,7 @@ def __init__(

         # more dt_proj init stuff. copied from https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L91-L101
         dt = torch.exp(
-            torch.rand(self.d_inner, **factory_kwargs)
+            torch.rand(self.d_inner_per_rank, **factory_kwargs)
             * (math.log(self.dt_max) - math.log(self.dt_min))
             + math.log(self.dt_min)
         ).clamp(min=self.dt_init_floor)
@@ -133,7 +147,7 @@ def __init__(
                 device=torch.cuda.current_device(),
             ),
             "n -> d n",
-            d=self.d_inner,
+            d=self.d_inner_per_rank,
         ).contiguous()
         A_log = torch.log(A).to(
             torch.float32
@@ -150,7 +164,9 @@ def __init__(
         # D parameter
         self.D = nn.Parameter(
             torch.ones(
-                self.d_inner, device=torch.cuda.current_device(), dtype=torch.float32
+                self.d_inner_per_rank,
+                device=torch.cuda.current_device(),
+                dtype=torch.float32,
             )
         ).to(
             torch.float32
@@ -163,14 +179,20 @@ def __init__(
         if self.neox_args.mamba_selective_fp32_params:
             self.D._deepspeed_no_cast = True

-        # out down-projection
-        self.out_proj = nn.Linear(
-            self.d_inner,
-            self.d_model,
+        # out down-projection.
+        # use "single_residual_scaled_normal"
+        # for output_layer_init_method
+        # to perform gpt-2 style scaled init as done in Mamba paper.
+        self.out_proj = mpu.RowParallelLinear(
+            neox_args=neox_args,
+            input_size=self.d_inner,
+            output_size=self.d_model,
+            input_is_parallel=True,
+            init_method=output_layer_init_method,
+            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
             bias=neox_args.mamba_use_bias_in_linears,
-            **factory_kwargs,
+            parallel_output=False,
         )
-        output_layer_init_method(self.out_proj.weight)

     def selective_scan(
         self,
@@ -224,14 +246,8 @@ def forward(self, hidden_states):
         seqlen, batch, dim = hidden_states.shape

         # first up: perform in_proj
-        xz = einops.rearrange(
-            self.in_proj.weight @ einops.rearrange(hidden_states, "l b d -> d (b l)"),
-            "d (b l) -> b d l",
-            l=seqlen,
-        )
-
-        if self.in_proj.bias is not None:
-            xz = xz + einops.rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
+        xz, _ = self.in_proj(hidden_states)
+        xz = einops.rearrange(xz, "l b d -> b d l")

         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

@@ -262,6 +278,12 @@ def forward(self, hidden_states):
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
             )
+            if getattr(self, "reduce", None):
+                # manually reduce after mamba_inner_fn
+                # to collect outputs from different TP ranks.
+                # handled by running self.out_proj(y) below
+                # so only needed here.
+                out = self.reduce(out)

             out = einops.rearrange(out, "b l h -> l b h")

@@ -292,7 +314,7 @@ def forward(self, hidden_states):
         # ==============

         # project: perform s_B, s_C, s_Delta projections
-        x_dbl = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
+        x_dbl, _ = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
         # split into component dt, B, C
         dt, B, C = torch.split(
             x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
@@ -324,14 +346,14 @@ def forward(self, hidden_states):
         # ===============

         y = einops.rearrange(y, "b d l -> b l d")
-        out = self.out_proj(y)
+        out, _ = self.out_proj(y)

         out = einops.rearrange(out, "b l h -> l b h")

         return out


-class MambaResidualLayer(nn.Module):
+class ParallelMambaResidualLayer(nn.Module):
     """
     Pre-norm Mamba Block with residual connection. No parallelism yet supported.
     """
@@ -352,7 +374,7 @@ def __init__(

         self.norm = norm(neox_args.hidden_size, eps=eps)

-        self.mixer = MambaBlock(
+        self.mixer = ParallelMambaBlock(
             neox_args=neox_args,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
@@ -369,7 +391,7 @@ def forward(self, x, attention_mask=None, layer_past=None):
         return hidden_states + residual


-class MambaResidualLayerPipe(MambaResidualLayer):
+class ParallelMambaResidualLayerPipe(ParallelMambaResidualLayer):
     """Extends MambaResidualLayer to forward attention_mask through the pipeline. DeepSpeed requires this."""

     def forward(self, args):
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index f06374e5d..bf6e3f3e8 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -1061,9 +1061,6 @@ def calculate_derived(self):
                 not self.partition_activations
             ), "GMLP Blocks are not compatible with partition activations"
         if "mamba" in self.attention_config:
-            assert (
-                not self.is_pipe_parallel and self.model_parallel_size == 1
-            ), "Mamba not currently compatible with parallelism"
             if isinstance(self.zero_stage, int):
                 assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
             assert (
diff --git a/tools/ckpts/convert_neox_to_mamba_ssm.py b/tools/ckpts/convert_neox_to_mamba_ssm.py
index 386a79a57..f87b36954 100644
--- a/tools/ckpts/convert_neox_to_mamba_ssm.py
+++ b/tools/ckpts/convert_neox_to_mamba_ssm.py
@@ -27,24 +27,26 @@
 """
 ARCH = {
     "COLUMN_PARALLEL_LINEAR_KEYS": {
+        # these require concat across dim=0
         "mixer.in_proj.weight": "mixer.in_proj.weight",
         # "mixer.in_proj.bias": "mixer.in_proj.bias",
+        "mixer.A_log": "mixer.A_log",
+        "mixer.D": "mixer.D",
+        "mixer.conv1d.weight": "mixer.conv1d.weight",
+        "mixer.conv1d.bias": "mixer.conv1d.bias",
+        "mixer.dt_proj.weight": "mixer.dt_proj.weight",
+        "mixer.dt_proj.bias": "mixer.dt_proj.bias",
     },
     "ROW_PARALLEL_LINEAR_KEYS": {
+        # these require concat across dim=1
         "mixer.out_proj.weight": "mixer.out_proj.weight",
+        "mixer.x_proj.weight": "mixer.x_proj.weight",
     },
     "ROW_PARALLEL_BIAS_KEYS": {
+        # these require summing across ranks
+        # "mixer.x_proj.bias": "mixer.x_proj.bias",
         # "mixer.out_proj.bias": "mixer.out_proj.bias",
     },
-    "NO_SHARD_KEYS": {
-        "mixer.A_log": "mixer.A_log",
-        "mixer.D": "mixer.D",
-        "mixer.x_proj.weight": "mixer.x_proj.weight",
-        "mixer.dt_proj.weight": "mixer.dt_proj.weight",
-        "mixer.dt_proj.bias": "mixer.dt_proj.bias",
-        "mixer.conv1d.weight": "mixer.conv1d.weight",
-        "mixer.conv1d.bias": "mixer.conv1d.bias",
-    },
     "NORM_KEYS": {
         "norm.scale": "norm.weight",
         # "norm.bias": "norm.bias",
@@ -226,15 +228,6 @@ def convert(
             )
         )

-        # Average params which aren't sharded across ranks.
-        # they should be the same across ranks, so should be fine
-        for key, hf_key in ARCH["NO_SHARD_KEYS"].items():
-            state_dict[hf_key] = sum(
-                get_state(
-                    loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
-                )
-            ) / len(loaded_tp_ranks)
-
         layer.load_state_dict(state_dict)

         if not sequential:
@@ -320,12 +313,6 @@ def main(input_args=None, overwrite_values=None):
         action="store_true",
         help="Whether to skip saving the tokenizer alongside a model.",
     )
-    parser.add_argument(
-        "--architecture",
-        type=str,
-        default="neox",
-        help="What HF model class type to export into.",
-    )
     args = parser.parse_args(input_args)

     # validate arguments
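Note on the tensor-parallel layout introduced above: in_proj becomes a ColumnParallelLinear (weight sharded along its output dimension, with gather_output=False), while x_proj and out_proj become RowParallelLinears (weight sharded along the input dimension), so each rank only ever holds d_inner_per_rank channels of the conv1d and SSM state. The fused mamba_inner_fn applies out_proj locally on each rank, which is why the patch adds the manual self.reduce; the non-fused path gets the same sum from RowParallelLinear itself. Below is a minimal single-process sketch in plain PyTorch of that arithmetic; it is not the mpu implementation, and the tensor names (w_in, w_out, x, y) and sizes are illustrative only.

import torch

torch.manual_seed(0)
d_model, d_inner, world_size = 16, 32, 2
d_inner_per_rank = d_inner // world_size  # what mpu.divide(self.d_inner, world_size) computes

x = torch.randn(4, d_model)               # (tokens, d_model) activations
w_in = torch.randn(2 * d_inner, d_model)  # in_proj weight: column-parallel, sharded along dim 0
w_out = torch.randn(d_model, d_inner)     # out_proj weight: row-parallel, sharded along dim 1

# Column-parallel in_proj with gather_output=False: rank 0 computes only its own
# slice of xz, identical to the matching slice of the unsharded product.
xz_full = x @ w_in.t()
xz_rank0 = x @ w_in[: 2 * d_inner_per_rank].t()
assert torch.allclose(xz_rank0, xz_full[:, : 2 * d_inner_per_rank], atol=1e-5)

# Row-parallel out_proj: each rank multiplies its local d_inner_per_rank slice and
# produces a partial (tokens, d_model) result; summing the partials is exactly the
# all-reduce that self.reduce / RowParallelLinear performs.
y = torch.randn(4, d_inner)               # full SSM output; sliced below to emulate per-rank shards
partials = []
for rank in range(world_size):
    cols = slice(rank * d_inner_per_rank, (rank + 1) * d_inner_per_rank)
    partials.append(y[:, cols] @ w_out[:, cols].t())
assert torch.allclose(sum(partials), y @ w_out.t(), atol=1e-5)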
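On the conversion-script side, the keys formerly in NO_SHARD_KEYS were identical across ranks and were simply averaged; now that the block is sharded, they move into COLUMN_PARALLEL_LINEAR_KEYS and are concatenated across dim=0, while out_proj.weight and x_proj.weight are concatenated across dim=1. A hypothetical helper sketching that bookkeeping (merge_tp_shards is not part of convert_neox_to_mamba_ssm.py, and it assumes the input dicts contain only the sharded mixer.* keys listed above):

from typing import Dict, List

import torch

# Row-parallel weights are sharded along their input (second) dimension.
ROW_PARALLEL_WEIGHTS = {"mixer.out_proj.weight", "mixer.x_proj.weight"}


def merge_tp_shards(shards: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stitch one layer's per-rank mixer tensors back into unsharded tensors."""
    merged = {}
    for key in shards[0]:
        tensors = [shard[key] for shard in shards]
        # column-parallel keys (in_proj, conv1d, dt_proj, A_log, D): concat across dim=0
        dim = 1 if key in ROW_PARALLEL_WEIGHTS else 0
        merged[key] = torch.cat(tensors, dim=dim)
    return merged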