diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 651e489..7dbb9ed 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -514,7 +514,7 @@ def tpu_sparse_dense_matmul( embedding_id, sample_id, gain, - embedding_variable[0], # [0] is the embedding table + embedding_variable.table, device_batch_size=stacked_table.total_sample_count // global_device_count, max_ids_per_partition=stacked_table.max_ids_per_partition,