You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
Previously to handle multiple VBE TBE output which is 1d tensor ordered by rank, we grouped sharding info such that there would only be one TBE created per sharding module. This avoided the issue of concatting multiple 1d tensors that are ordered by rank (not a problem in on VBE bc of 2d output which we can concat on dim 1).
This grouping which would be done only applies to specific UVM caching setups that used prefetch pipeline, as each sharding type could require multiple TBE to handle both HBM and UVM caching setups. In most cases the TBE could be fused for each sharding type, so we grouped by such.
Each sharding module handles individual input dist, lookup, output dist, and by creating a sharding module per each TBE in EMO setups would cause regression, as there would be an increase in comms to handle the increased input dists and output dists.
This diff removes the need for the grouping logic to circumvent the VBE TBE output concatenation by implementing output concatenation, which removes the necessity for specialized sharding grouping logic for specific EMO cases.
Differential Revision: D58894728
0 commit comments