@@ -182,10 +182,11 @@ def finalize_layer_weights(self):
         offset = 0
         for name in self.names:
             for expert_id in range(self.expert_start, self.expert_end):
-                t = self.shared_tensors[(expert_id, name)]
+                t = self.shared_tensors[(expert_id, name)].contiguous().cpu()
                 data_size = t.numel() * t.element_size()
                 aligned_size = self.align_size(data_size)
-                shm.buf[offset:offset + data_size] = t.numpy().tobytes()
+                shm.buf[offset:offset + data_size] = t.flatten().view(
+                    torch.int8).numpy().tobytes()
                 dtype = t.dtype
                 tensor_shape = t.shape
                 elt_count = t.numel()
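The int8 view above is what makes this work for low-precision weights: numpy has no bfloat16 (or fp8) dtype, so `t.numpy()` raises for such tensors, while reinterpreting the flattened storage as int8 first keeps the byte copy dtype-agnostic. A minimal standalone sketch of the same round trip (assumes only PyTorch; the tensor is illustrative):

    import torch

    # Any dtype works, including ones numpy cannot represent (e.g. bfloat16).
    t = torch.randn(4, 8, dtype=torch.bfloat16).contiguous().cpu()

    # Reinterpret the flattened storage as raw int8 bytes, then hand off to numpy.
    raw = t.flatten().view(torch.int8).numpy().tobytes()
    assert len(raw) == t.numel() * t.element_size()

    # Round trip: rebuild the tensor from the bytes plus the saved dtype/shape.
    restored = torch.frombuffer(bytearray(raw),
                                dtype=torch.int8).view(t.dtype).reshape(t.shape)
    assert torch.equal(restored, t)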
@@ -270,7 +271,8 @@ def __init__(
             single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer,
             shared_mpi_comm: MPI.Comm,
             expert_count: int,
-            updates_enabled: bool = True):
+            updates_enabled: bool = True,
+            repeated_count=1):
         """
         Initialize a SingleLayerMoeLoadBalancer instance.

@@ -279,6 +281,7 @@ def __init__(
             shared_mpi_comm: The MPI communicator for shared memory
             expert_count: total number of experts
             updates_enabled: whether to enable weight updates
+            repeated_count: the repeat count of the current layer, used when the forward pass runs more than once per iteration, e.g. for MTP
         """
         self.single_layer_load_balancer_impl = single_layer_load_balancer_impl
         self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer(
@@ -306,6 +309,7 @@ def __init__(

         self.cudagraph_stream = None
         self.cudagraph_event = None
+        self.repeated_count = repeated_count

         self.statistic_stream = None
         self.statistic_event = None
@@ -317,6 +321,9 @@ def get_load_expert_ids(self):
         assert self.updates_enabled, "should not call get_load_expert_ids when using statistic routing"
         return self.load_expert_ids

+    def get_repeat_count(self):
+        return self.repeated_count
+
     def is_static_routing(self):
         return not self.updates_enabled

@@ -675,6 +682,8 @@ def __init__(self,
         self.enable_statistic = False
         self.enable_update_weights = False

+        self.next_layer_repeated_count = None
+
     def __del__(self):
         if not self.is_shutdown:
             self.shutdown()
@@ -696,6 +705,16 @@ def _setup_mpi_comm(self):
     def set_use_gpu_memcpy(self, use_gpu_memcpy: bool):
         self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy)

+    def set_repeated_for_next_layer(self, repeated_count: int):
+        """
+        Set the repeat count for the next layer added.
+
+        Args:
+            repeated_count: The repeat count for the next layer
+        """
+        assert repeated_count > 0, "repeat count must be greater than 0"
+        self.next_layer_repeated_count = repeated_count
+
     def add_layer(self, expert_count: int, top_k: int,
                   slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
         """
@@ -712,11 +731,16 @@ def add_layer(self, expert_count: int, top_k: int,
         single_layer_load_balancer_impl = self.load_balancer_impl.add_layer(
             expert_count, top_k, slot_count_per_rank)
         updates_enabled = not self.is_static_routing()
+        repeat_count = 1
+        if self.next_layer_repeated_count is not None:
+            repeat_count = self.next_layer_repeated_count
+            self.next_layer_repeated_count = None
         single_layer_load_balancer = SingleLayerMoeLoadBalancer(
             single_layer_load_balancer_impl,
             self.shared_mpi_comm,
             expert_count,
-            updates_enabled=updates_enabled)
+            updates_enabled=updates_enabled,
+            repeated_count=repeat_count)
         single_layer_load_balancer.set_shared_memory_base_name(
             self.shared_memory_base_name)
         self.single_layer_load_balancers.append(single_layer_load_balancer)
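Note the consume-once semantics: `add_layer` reads `next_layer_repeated_count` for the layer being created and immediately clears it, so the setting affects only the very next layer. A hypothetical call pattern, given an already constructed `MoeLoadBalancer` named `balancer` (the expert/slot numbers are illustrative):

    balancer.set_repeated_for_next_layer(2)  # e.g. a layer whose forward runs twice per step
    layer_a = balancer.add_layer(expert_count=256, top_k=8, slot_count_per_rank=72)
    layer_b = balancer.add_layer(expert_count=256, top_k=8, slot_count_per_rank=72)
    assert layer_a.get_repeat_count() == 2  # picked up the pending repeat count
    assert layer_b.get_repeat_count() == 1  # setting was consumed and reset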
@@ -934,6 +958,18 @@ def get_moe_load_balancer() -> Optional[MoeLoadBalancer]:
     return None


+def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int):
+    """
+    Set the repeated count for the next SingleLayerMoeLoadBalancer created.
+
+    Args:
+        repeat_count: the repeated count
+    """
+    load_balancer = get_moe_load_balancer()
+    if load_balancer is not None:
+        load_balancer.set_repeated_for_next_layer(repeat_count)
+
+
 def moe_load_balancer_add_single_layer(
         expert_count: int, top_k: int,
         slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]:
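The module-level wrapper lets model code request a repeat count without holding a `MoeLoadBalancer` reference: it resolves the active balancer via `get_moe_load_balancer()` and silently does nothing when load balancing is disabled. A hedged sketch of the intended use before creating an MoE layer that runs more than once per iteration, as with MTP (`num_mtp_forwards` and the argument values are hypothetical):

    # Announce that the next layer's forward runs (1 + num_mtp_forwards) times;
    # this is a no-op when no MoeLoadBalancer is active.
    moe_load_balancer_set_repeated_for_next_layer(1 + num_mtp_forwards)
    layer_balancer = moe_load_balancer_add_single_layer(
        expert_count=256, top_k=8, slot_count_per_rank=72)  # None if balancing is off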