25 changes: 24 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -866,23 +866,46 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
Example::

kjt_split = [1, 2]

# the kjt_split informs the number of features owned by each rank, here t0 owns f0 and
# t1 owns f1 and f2.

emb_dim_per_rank_per_feature = [[2], [3, 3]]
a2a = VariableBatchPooledEmbeddingsAllToAll(
pg, emb_dim_per_rank_per_feature, device
)

t0 = torch.rand(6) # 2 * (2 + 1)
t1 = torch.rand(24) # 3 * (1 + 2) + 3 * (3 + 2)

# t0 and t1 are the flattened send buffers of pooled embedding outputs produced on the
# ranks that own the features. Each buffer's length is the sum, over the features owned
# by the sending rank, of embedding_dim * (sum of that feature's variable batch sizes
# across all destination ranks).
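# e.g. rank 0 owns only f0 (dim 2) and sends batch sizes 2 (to r0) and 1 (to r1): 2 * (2 + 1) = 6
# rank 1 owns f1 and f2 (dim 3 each): 3 * (1 + 2) + 3 * (3 + 2) = 9 + 15 = 24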

#        r0_batch_size   r1_batch_size
# f0:          2               1
# f1:          1               2
# f2:          3               2

# the batch_size_per_rank_per_feature argument is specified from the perspective of the sending rank:
# outer index = destination rank, inner list = features owned by the sending rank
# (in emb_dim_per_rank_per_feature order)

r0_batch_size_per_rank_per_feature = [[2], [1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
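# reading r1's nested list: destination r0 receives (f1: 1, f2: 3) and destination r1 receives (f1: 2, f2: 2)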

# r0 wants r1 wants
# f0: 2 1
# f1: 1 2
# f2: 3 2
# which informs the per_feature_pre_a2a vectors

r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

# r0 should receive f0: 2 (from r0), f1: 1 (from r1), f2: 3 (from r1)
# r1 should receive f0: 1 (from r0), f1: 2 (from r1), f2: 2 (from r1)
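# total values received: r0: 2 * 2 + 3 * 1 + 3 * 3 = 16, r1: 2 * 1 + 3 * 2 + 3 * 2 = 14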

rank0_output = a2a(
t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
).wait()
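
# a minimal sketch of the mirrored call on rank 1, which sends t1 together with its own
# bookkeeping; the rest of the docstring example is truncated in this diff view, so the
# name rank1_output is an assumption
rank1_output = a2a(
    t1, r1_batch_size_per_rank_per_feature, r1_batch_size_per_feature_pre_a2a
).wait()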