Skip to content

Introduce FdoStats struct for tracking FDO info. #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -503,18 +503,16 @@ def train_step_fn(
lambda y: jax.make_array_from_process_local_data(global_sharding, y),
x,
)
preprocessed_inputs, stats = map(
make_global_view,
embedding.preprocess_sparse_dense_matmul_input(
features,
feature_weights,
feature_specs,
local_device_count=global_mesh.local_mesh.size,
global_device_count=global_mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
),
preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
features,
feature_weights,
feature_specs,
local_device_count=global_mesh.local_mesh.size,
global_device_count=global_mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
)
preprocessed_inputs = make_global_view(preprocessed_inputs)
fdo_client.record(stats)

# ----------------------------------------------------------------------
28 changes: 26 additions & 2 deletions jax_tpu_embedding/sparsecore/lib/core/BUILD
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "CORE_USERS")
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")

@@ -77,6 +77,7 @@ pybind_extension(
name = "input_preprocessing_cc",
srcs = ["input_preprocessing.cc"],
deps = [
":fdo_types",
":input_preprocessing_threads",
":input_preprocessing_util",
"@com_google_absl//absl/container:flat_hash_map",
@@ -89,6 +90,24 @@ pybind_extension(
],
)

pybind_library(
name = "fdo_types",
hdrs = ["fdo_types.h"],
deps = ["@com_google_absl//absl/container:flat_hash_map"],
)

pybind_extension(
name = "fdo_types_cc",
srcs = [
"fdo_types.cc",
"fdo_types.h",
],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@pybind11_abseil//:absl_casters",
],
)

pytype_strict_library(
name = "input_preprocessing",
srcs = [
@@ -160,12 +179,17 @@ pytype_strict_library(
srcs = ["__init__.py"],
# C++ dependencies must go in "data".
data = [
":fdo_types", # buildcleaner: keep
":input_preprocessing_threads", # buildcleaner: keep
":input_preprocessing_util", # buildcleaner: keep
],
visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"],
visibility = [
"//jax_tpu_embedding/sparsecore/lib:__pkg__",
"//jax_tpu_embedding/sparsecore/lib/extensions:__pkg__",
],
deps = [
":constants", # buildcleaner: keep
":fdo_types_cc", # buildcleaner: keep
":input_preprocessing", # buildcleaner: keep
":input_preprocessing_cc", # buildcleaner: keep
"//jax_tpu_embedding/sparsecore/lib/core/primitives", # buildcleaner: keep
36 changes: 36 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/fdo_types.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright 2024 The JAX SC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"

#include "pybind11/cast.h" // from @pybind11
#include "pybind11/numpy.h" // from @pybind11
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/pytypes.h" // from @pybind11
#include "pybind11/stl.h" // from @pybind11
#include "third_party/pybind11_abseil/absl_casters.h"

namespace jax_sc_embedding {

namespace py = ::pybind11;

// Python bindings for FdoStats. Exposes the struct to Python as the
// `FdoStats` class in the `fdo_types_cc` extension module, with each
// per-stacked-table statistics map surfaced as a read-only attribute.
// Map/vector conversion is handled by the pybind11 stl/absl casters
// included above.
PYBIND11_MODULE(fdo_types_cc, m) {
  py::class_<FdoStats>(m, "FdoStats")
      .def_readonly("max_ids_per_partition", &FdoStats::max_ids_per_partition)
      .def_readonly("max_unique_ids_per_partition",
                    &FdoStats::max_unique_ids_per_partition)
      .def_readonly("id_drop_counters", &FdoStats::id_drop_counters)
      .def_readonly("required_buffer_sizes", &FdoStats::required_buffer_sizes);
}

} // namespace jax_sc_embedding
40 changes: 40 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/fdo_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 The JAX SC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h" // from @com_google_absl

namespace jax_sc_embedding {

// Container for FDO statistics collected during sparse-dense matmul input
// preprocessing, grouped per stacked embedding table. Each entry holds one
// value per SparseCore.
struct FdoStats {
  // One statistic value per SparseCore.
  using FdoStatsPerSparseCore = std::vector<int>;

  // Key type: the name of a stacked embedding table.
  using StackedTableName = std::string;

  // Per-SparseCore statistics keyed by stacked table name.
  using FdoStatsPerStackedTable =
      absl::flat_hash_map<StackedTableName, FdoStatsPerSparseCore>;

  // Observed max ids per partition, per SparseCore.
  FdoStatsPerStackedTable max_ids_per_partition;
  // Observed max unique ids per partition, per SparseCore.
  FdoStatsPerStackedTable max_unique_ids_per_partition;
  // Counts of dropped ids per SparseCore (populated when id dropping is
  // allowed — presumably by the preprocessing path; confirm with callers).
  FdoStatsPerStackedTable id_drop_counters;
  // Buffer sizes required by preprocessing, per SparseCore.
  FdoStatsPerStackedTable required_buffer_sizes;
};

} // namespace jax_sc_embedding

#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
74 changes: 36 additions & 38 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
@@ -24,6 +26,7 @@
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/cast.h" // from @pybind11
@@ -250,8 +253,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
const absl::string_view stacked_table_name, const bool allow_id_dropping,
py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
py::array_t<int> required_buffer_size_per_sc_buffer) {
absl::Span<int> max_ids_buffer, absl::Span<int> max_unique_ids_buffer,
absl::Span<int> required_buffer_size_per_sc_buffer) {
const int num_scs = num_sc_per_device * num_global_devices;
int batch_size_for_device = 0;
int total_num_coo_tensors = 0;
@@ -299,10 +302,6 @@ void PreprocessInputForStackedTablePerLocalDevice(
auto* embedding_ids_data = embedding_id_buffer.mutable_data();
auto* sample_ids_data = sample_id_buffer.mutable_data();
auto* gains_data = gain_buffer.mutable_data();
auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
auto* required_buffer_size_per_sc =
required_buffer_size_per_sc_buffer.mutable_data();
// The remaining section does not require GIL.
py::gil_scoped_release release;

@@ -318,8 +317,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
stacked_table_metadata[0].max_ids_per_partition,
stacked_table_metadata[0].max_unique_ids_per_partition,
stacked_table_name, allow_id_dropping, num_sc_per_device,
total_num_coo_tensors, total_max_ids_per_sc,
total_max_unique_ids_per_sc, required_buffer_size_per_sc);
total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
required_buffer_size_per_sc_buffer);
for (int i = 0; i < num_sc_per_device; ++i) {
coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
}
@@ -359,6 +358,13 @@ static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
(start_index + 1) * first_dim_size, 1);
}

// Returns the mutable slice of `stats` owned by local device `device_index`.
// The stats vector is laid out as `local_device_count` consecutive segments
// of `stats_size_per_device` elements each, so device i owns the half-open
// index range [i * stats_size_per_device, (i + 1) * stats_size_per_device).
//
// BUG FIX: absl::Span::subspan(pos, len) takes a *length* as its second
// argument, not an end index. The previous code passed
// (device_index + 1) * stats_size_per_device as the length, which
// over-extended each device's slice into the following devices' segments
// (absl clamps at the end of the vector, so only the last device's slice
// happened to have the intended size).
static inline absl::Span<int> GetStatsSliceForGivenDevice(
    std::vector<int>& stats, int device_index, int stats_size_per_device) {
  return absl::MakeSpan(stats).subspan(
      device_index * stats_size_per_device, stats_size_per_device);
}

py::tuple PreprocessSparseDenseMatmulInput(
py::list features, py::list feature_weights, py::list feature_specs,
const int local_device_count, const int global_device_count,
@@ -379,9 +385,9 @@ py::tuple PreprocessSparseDenseMatmulInput(
py::dict lhs_embedding_ids;
py::dict lhs_sample_ids;
py::dict lhs_gains;
py::dict max_ids_per_partition;
py::dict max_unique_ids_per_partition;
py::dict required_buffer_sizes;
FdoStats::FdoStatsPerStackedTable max_ids_per_partition;
FdoStats::FdoStatsPerStackedTable max_unique_ids_per_partition;
FdoStats::FdoStatsPerStackedTable required_buffer_sizes;
const int num_scs = num_sc_per_device * global_device_count;
const int row_pointers_size_per_sc = std::max(num_scs, 8);

@@ -437,15 +443,10 @@ py::tuple PreprocessSparseDenseMatmulInput(
py::array_t<float> gains_per_device =
py::array_t<float>(shape_container);
const int stats_size_per_device = num_scs;
py::array::ShapeContainer stats_shape = GetArrayShapeBasedOnLeadingDim(
/*has_leading_dimension=*/false, local_device_count,
stats_size_per_device);
py::array_t<int> max_ids_per_partition_per_sc =
py::array_t<int>(stats_shape);
py::array_t<int> max_unique_ids_per_partition_per_sc =
py::array_t<int>(stats_shape);
py::array_t<int> required_buffer_size_per_sc =
py::array_t<int>(stats_shape);
size_t stats_size = local_device_count * stats_size_per_device;
std::vector<int> max_ids_per_partition_per_sc(stats_size);
std::vector<int> max_unique_ids_per_partition_per_sc(stats_size);
std::vector<int> required_buffer_size_per_sc(stats_size);
for (int local_device = 0; local_device < local_device_count;
++local_device) {
// Get the tuple outputs for the current split.
@@ -459,15 +460,14 @@ py::tuple PreprocessSparseDenseMatmulInput(
embedding_ids_per_device[static_buffer_slice];
auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
auto gain_buffer = gains_per_device[static_buffer_slice];
py::slice stats_slice =
GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
local_device, stats_size_per_device);
auto max_ids_per_partition_per_sc_buffer =
max_ids_per_partition_per_sc[stats_slice];
auto max_unique_ids_per_partition_per_sc_buffer =
max_unique_ids_per_partition_per_sc[stats_slice];
auto required_buffer_size_per_sc_buffer =
required_buffer_size_per_sc[stats_slice];
auto device_max_ids_per_partition =
GetStatsSliceForGivenDevice(max_ids_per_partition_per_sc,
local_device, stats_size_per_device);
auto device_max_unique_ids_per_partition =
GetStatsSliceForGivenDevice(max_unique_ids_per_partition_per_sc,
local_device, stats_size_per_device);
auto device_required_buffer_size = GetStatsSliceForGivenDevice(
required_buffer_size_per_sc, local_device, stats_size_per_device);
PreprocessInputForStackedTablePerLocalDevice(
stacked_table_metadata, features, feature_weights, local_device,
local_device_count, coo_buffer_size, row_pointers_size_per_sc,
@@ -477,10 +477,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
py::cast<py::array_t<int>>(embedding_id_buffer),
py::cast<py::array_t<int>>(sample_id_buffer),
py::cast<py::array_t<float>>(gain_buffer),
py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
py::cast<py::array_t<int>>(
max_unique_ids_per_partition_per_sc_buffer),
py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
device_max_ids_per_partition, device_max_unique_ids_per_partition,
device_required_buffer_size);
}
lhs_row_pointers[stacked_table_name.c_str()] =
std::move(row_pointers_per_device);
@@ -500,11 +498,11 @@ py::tuple PreprocessSparseDenseMatmulInput(
}
counter.Wait();
}
py::dict stats;
stats["max_ids"] = max_ids_per_partition;
stats["max_unique_ids"] = max_unique_ids_per_partition;
stats["required_buffer_size"] = std::move(required_buffer_sizes);

FdoStats stats{
.max_ids_per_partition = max_ids_per_partition,
.max_unique_ids_per_partition = max_unique_ids_per_partition,
.required_buffer_sizes = required_buffer_sizes,
};
// GIL is held at this point.
return py::make_tuple(lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids,
lhs_gains, stats);
Original file line number Diff line number Diff line change
@@ -763,7 +763,7 @@ def test_multi_process_fdo(self, has_leading_dimension):
allow_id_dropping=False,
)
)
stats = embedding.SparseDenseMatmulInputStats.from_dict(stats)
stats = embedding.SparseDenseMatmulInputStats.from_cc(stats)
fdo_client.record(stats)
fdo_client.publish()
# Duplicated ids on row 0 and 6 are combined.
Original file line number Diff line number Diff line change
@@ -105,8 +105,8 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
const int32_t max_unique_ids_per_partition,
const absl::string_view stacked_table_name, const bool allow_id_dropping,
const int num_sc_per_device, const int total_num_coo_tensors,
int max_ids_per_sc[], int max_unique_ids_per_sc[],
int required_buffer_size_per_sc[]) {
absl::Span<int> max_ids_per_sc, absl::Span<int> max_unique_ids_per_sc,
absl::Span<int> required_buffer_size_per_sc) {
tsl::profiler::TraceMe t("SortAndGroupCooTensors");
const int local_sc_count = batch_size_for_device / batch_size_per_sc;
std::vector<std::vector<CooFormat>> coo_tensors_by_id;
Original file line number Diff line number Diff line change
@@ -120,18 +120,19 @@ struct StackedTableMetadata {

std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
absl::Span<const CooFormat> coo_tensors, int batch_size_per_sc,
int global_sc_count,
int32_t batch_size_for_device, // Batch size for the local device.
int global_sc_count, int32_t batch_size_for_device,
int32_t max_ids_per_partition, int32_t max_unique_ids_per_partition,
absl::string_view stacked_table_name, bool allow_id_dropping,
int num_sc_per_device, int total_num_coo_tensors, int max_ids_per_sc[],
int max_unique_ids_per_sc[], int required_buffer_size_per_sc[]);
int num_sc_per_device, int total_num_coo_tensors,
absl::Span<int> max_ids_per_sc, absl::Span<int> max_unique_ids_per_sc,
absl::Span<int> required_buffer_size_per_sc);

int ComputeCooBufferSize(
int num_scs, int num_scs_per_device,
absl::Span<const StackedTableMetadata> stacked_table_metadata,
int static_buffer_size_multiplier);


void IncrementScId(std::pair<int, int>& sc_id, int num_scs,
int num_scs_per_device);

1 change: 1 addition & 0 deletions jax_tpu_embedding/sparsecore/lib/nn/BUILD
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ pytype_strict_library(
deps = [
":embedding_spec",
":table_stacking",
"//jax_tpu_embedding/sparsecore/lib/core:fdo_types_cc",
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
"//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr",
"//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",
11 changes: 6 additions & 5 deletions jax_tpu_embedding/sparsecore/lib/nn/embedding.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
from jax.experimental.layout import DeviceLocalLayout as DLL
from jax.experimental.layout import Layout
import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.core import fdo_types_cc
from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc
from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
@@ -69,12 +70,12 @@ class SparseDenseMatmulInputStats:
max_unique_ids_per_partition: Mapping[str, np.ndarray]

@classmethod
def from_dict(
cls, stats: Mapping[str, Mapping[str, np.ndarray]]
def from_cc(
cls, stats: fdo_types_cc.FdoStats
) -> "SparseDenseMatmulInputStats":
return cls(
max_ids_per_partition=stats["max_ids"],
max_unique_ids_per_partition=stats["max_unique_ids"],
max_ids_per_partition=stats.max_ids_per_partition,
max_unique_ids_per_partition=stats.max_unique_ids_per_partition,
)


@@ -380,7 +381,7 @@ def preprocess_sparse_dense_matmul_input(

return SparseDenseMatmulInput(
*preprocessed_inputs
), SparseDenseMatmulInputStats.from_dict(stats)
), SparseDenseMatmulInputStats.from_cc(stats)


def _get_activation_for_feature(
1 change: 1 addition & 0 deletions jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ py_binary(
name = "preprocess_input_benchmarks",
srcs = ["preprocess_input_benchmarks.py"],
deps = [
"//jax_tpu_embedding/sparsecore/lib/core:fdo_types_cc",
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
pypi_requirement("google_benchmark"),
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
"""

import google_benchmark
from jax_tpu_embedding.sparsecore.lib.core import fdo_types_cc # pylint: disable=unused-import # used for type conversion of stat data type
from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
import numpy as np
@@ -108,6 +109,7 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False):
all_feature_weights.append(np.array(feature_weights, dtype=object))
return all_features, all_feature_weights


# Total local batch size that is measured is 16000x100 = 1,600,000.
_GLOBAL_SPECS = generate_feature_specs(num_features=100)
_GLOBAL_RAGGED_FEATURES, _GLOBAL_RAGGED_FEATURE_WEIGHTS = (