Skip to content

Creating RocksDBCheckpointHandler to expose rocksdb checkpoint to python #4224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,7 @@ def split_embedding_weights(
no_snapshot=no_snapshot,
should_flush=should_flush,
)
checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)

dtype = self.weights_precision.as_dtype()
if self.load_state_dict and self.kv_zch_params:
Expand Down Expand Up @@ -2286,6 +2287,7 @@ def split_embedding_weights(
sorted_indices=(
bucket_ascending_id_tensor if self.kv_zch_params else None
),
checkpoint_handle=checkpoint_handle,
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
Expand Down Expand Up @@ -2520,6 +2522,12 @@ def flush(self, force: bool = False) -> None:
self.ssd_db.flush()
self.last_flush_step = self.step

def create_rocksdb_hard_link_snapshot(self) -> None:
    """
    Create a rocksdb hard link snapshot to provide cross procs access to the underlying data.

    Delegates to the bound ``self.ssd_db`` (the C++ ``EmbeddingRocksDBWrapper``),
    tagging the checkpoint with the current training step ``self.step`` so it
    can later be looked up via ``get_active_checkpoint_uuid(step)``.
    """
    self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)

def prepare_inputs(
self,
indices: Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@

#include <ATen/ATen.h>
#include <folly/hash/Hash.h>
#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>
#include <filesystem>
#include <optional>

/// @defgroup embedding-ssd Embedding SSD Operators
///

namespace kv_db_utils {

#ifdef FBGEMM_FBCODE
// Number of SSD mounts that rocksdb shards are striped across when the
// default base path is used (see get_rocksdb_path below).
// NOTE(review): this constant is only defined for FBGEMM_FBCODE builds, yet
// get_rocksdb_path references it unconditionally — confirm non-fbcode builds
// never compile/take the default_path branch.
constexpr size_t num_ssd_drives = 8;
#endif

/// @ingroup embedding-ssd
///
/// @brief hash function used for SSD L2 cache and rocksdb sharding algorithm
Expand Down Expand Up @@ -65,4 +70,94 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
std::optional<int64_t> bucket_size,
std::optional<int64_t> total_num_buckets);

/// @ingroup embedding-ssd
///
/// @brief compute the rocksdb base path for one shard from a user-provided
/// base_path; the resulting hierarchy is
/// <base_path><ssd_idx>/<tbe_uuid> when using the default SSD mount, and
/// <base_path>/<tbe_uuid> when the caller supplied its own base path
///
/// @param base_path the base path shared by all rocksdb shards of one
/// TBE/EmbeddingRocksDB
/// @param db_shard_id the rocksdb shard index, used to pick an SSD drive
/// @param tbe_uuid unique identifier per TBE for the lifetime of a training
/// job
/// @param default_path whether base_path is the default SSD mount (true) or
/// user-provided (false)
///
/// @return the base path for that rocksdb shard
inline std::string get_rocksdb_path(
    const std::string& base_path,
    int db_shard_id,
    const std::string& tbe_uuid,
    bool default_path) {
  if (!default_path) {
    // User-supplied base path: every shard lives under <base_path>/<tbe_uuid>.
    return base_path + "/" + tbe_uuid;
  }
  // Default SSD mount: stripe shards round-robin across the available drives,
  // producing <base_path><drive_idx>/<tbe_uuid> (no separator after base_path
  // by design — the drive index is appended directly to the mount prefix).
  int drive_idx = db_shard_id % num_ssd_drives;
  return base_path + std::to_string(drive_idx) + "/" + tbe_uuid;
}

/// @ingroup embedding-ssd
///
/// @brief compute the path for a single rocksdb shard; the resulting
/// hierarchy is <rocksdb_path>/shard_<db_shard_id>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for the rocksdb shard
///
/// @return the rocksdb shard path
inline std::string get_rocksdb_shard_path(
    int db_shard_id,
    const std::string& rocksdb_path) {
  std::string shard_path = rocksdb_path;
  shard_path.append("/shard_").append(std::to_string(db_shard_id));
  return shard_path;
}

/// @ingroup embedding-ssd
///
/// @brief compute the directory that holds rocksdb checkpoints for one
/// rocksdb shard; the resulting hierarchy is
/// <rocksdb_path>/checkpoint_shard_<db_shard_id>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for the rocksdb shard
///
/// @return the checkpoint directory for that rocksdb shard
inline std::string get_rocksdb_checkpoint_dir(
    int db_shard_id,
    const std::string& rocksdb_path) {
  std::string ckpt_dir = rocksdb_path;
  ckpt_dir.append("/checkpoint_shard_").append(std::to_string(db_shard_id));
  return ckpt_dir;
}

/// Create dir_path and any missing parent directories. Failures are logged
/// rather than thrown (best-effort: an already-existing directory is logged
/// at ERROR, matching the previous behavior, and any filesystem exception is
/// swallowed after logging).
inline void create_dir(const std::string& dir_path) {
  try {
    const std::filesystem::path fs_path{dir_path};
    if (!std::filesystem::create_directories(fs_path)) {
      // create_directories returns false when nothing needed to be created.
      LOG(ERROR) << "dir: " << dir_path << " already exists";
    }
  } catch (const std::exception& e) {
    LOG(ERROR) << "Error creating directory: " << e.what();
  }
}

/// Remove path if it exists: directories are removed recursively, plain
/// files are unlinked. Errors are logged and swallowed (best-effort cleanup).
inline void remove_dir(const std::string& path) {
  if (!std::filesystem::exists(path)) {
    return;
  }
  try {
    if (std::filesystem::is_directory(path)) {
      // Recursively delete the directory tree.
      std::filesystem::remove_all(path);
    } else {
      std::filesystem::remove(path);
    }
  } catch (const std::filesystem::filesystem_error& e) {
    LOG(ERROR) << "Error removing path: " << path
               << ", exception:" << e.what();
  }
}
}; // namespace kv_db_utils
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->get_snapshot_count();
}

// Create a hard-link based rocksdb checkpoint for every shard, tagged with
// the trainer's global step; delegates to EmbeddingRocksDB::create_checkpoint.
void create_rocksdb_hard_link_snapshot(int64_t global_step) {
  impl_->create_checkpoint(global_step);
}

// Look up the checkpoint active at `global_step` and wrap its uuid in a
// RocksdbCheckpointHandleWrapper that shares ownership of the underlying db.
// Returns nullptr when no checkpoint is active for that step, so callers
// must null-check the handle.
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> get_active_checkpoint_uuid(
    int64_t global_step) {
  auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step);
  if (uuid_opt.has_value()) {
    return c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
        uuid_opt.value(), impl_);
  } else {
    return nullptr;
  }
}

private:
friend class KVTensorWrapper;

Expand Down
21 changes: 20 additions & 1 deletion fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <ATen/Tensor.h> // @manual=//caffe2:ATen-core
#include <nlohmann/json.hpp>
#include <torch/custom_class.h>

namespace kv_mem {
Expand All @@ -21,7 +22,10 @@ class EmbeddingKVDB;

namespace ssd {

using json = nlohmann::json;

class EmbeddingRocksDB;
class ReadOnlyEmbeddingKVDB;
class EmbeddingRocksDBWrapper;
class SnapshotHandle;

Expand All @@ -37,6 +41,18 @@ struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
std::shared_ptr<EmbeddingRocksDB> db;
};

// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
// RAII handle (exposed as a TorchScript custom class) that keeps a rocksdb
// checkpoint alive: it records the checkpoint's uuid together with the
// owning db, and the destructor releases the checkpoint on the db.
struct RocksdbCheckpointHandleWrapper : public torch::jit::CustomClassHolder {
  explicit RocksdbCheckpointHandleWrapper(
      const std::string& checkpoint_uuid,
      std::shared_ptr<EmbeddingRocksDB> db);

  ~RocksdbCheckpointHandleWrapper();

  // uuid identifying the checkpoint within `db`.
  std::string uuid;
  // Shared ownership so the db outlives this handle.
  std::shared_ptr<EmbeddingRocksDB> db;
};

class KVTensorWrapper : public torch::jit::CustomClassHolder {
public:
explicit KVTensorWrapper(
Expand All @@ -46,7 +62,9 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle = std::nullopt,
std::optional<at::Tensor> sorted_indices = std::nullopt,
int64_t width_offset = 0);
int64_t width_offset = 0,
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> checkpoint_handle =
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>(nullptr));

at::Tensor narrow(int64_t dim, int64_t start, int64_t length);

Expand Down Expand Up @@ -100,6 +118,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
std::optional<at::Tensor> sorted_indices_ = std::nullopt;
int64_t width_offset_;
std::mutex mtx;
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> checkpoint_handle_;
};

} // namespace ssd
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ KVTensorWrapper::KVTensorWrapper(
[[maybe_unused]] const std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle,
[[maybe_unused]] const std::optional<at::Tensor> sorted_indices,
[[maybe_unused]] int64_t width_offset)
[[maybe_unused]] int64_t width_offset,
[[maybe_unused]] c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>)
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
: shape_(std::move(shape)), row_offset_(row_offset) {
FBEXCEPTION("Not implemented");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <memory>

#include "embedding_rocksdb_wrapper.h"
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "rocksdb/utilities/checkpoint.h"
using namespace at;
using namespace ssd;
using namespace kv_mem;
Expand Down Expand Up @@ -293,6 +293,37 @@ snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
return shard_snapshots_[shard];
}

// Build a hard-link checkpoint of every rocksdb shard belonging to this TBE.
// For each shard we create <rocksdb_path>/checkpoint_shard_<i>/<ckpt_uuid>
// via rocksdb::Checkpoint and remember the resulting path. CHECK-fails (and
// aborts) if checkpoint creation fails for any shard.
CheckpointHandle::CheckpointHandle(
    EmbeddingRocksDB* db,
    const std::string& tbe_uuid,
    const std::string& ckpt_uuid,
    const std::string& base_path,
    bool use_default_ssd_path)
    : db_(db), ckpt_uuid_(ckpt_uuid) {
  auto num_shards = db->num_shards();
  CHECK_GT(num_shards, 0);
  shard_checkpoints_.reserve(num_shards);
  for (auto shard = 0; shard < num_shards; ++shard) {
    // Mirror the shard's on-disk layout so checkpoints land next to the db.
    auto rocksdb_path = kv_db_utils::get_rocksdb_path(
        base_path, shard, tbe_uuid, use_default_ssd_path);
    auto checkpoint_shard_dir =
        kv_db_utils::get_rocksdb_checkpoint_dir(shard, rocksdb_path);
    kv_db_utils::create_dir(checkpoint_shard_dir);
    rocksdb::Checkpoint* checkpoint = nullptr;
    rocksdb::Status s =
        rocksdb::Checkpoint::Create(db->dbs_[shard].get(), &checkpoint);
    CHECK(s.ok()) << "ERROR: Checkpoint init for tbe_uuid " << tbe_uuid
                  << ", db shard " << shard << " failed, " << s.code() << ", "
                  << s.ToString();
    // Checkpoint::Create heap-allocates the Checkpoint object and transfers
    // ownership to the caller; take ownership so it is freed after use
    // (previously this object was leaked once per shard).
    std::unique_ptr<rocksdb::Checkpoint> checkpoint_owner(checkpoint);
    std::string checkpoint_shard_path = checkpoint_shard_dir + "/" + ckpt_uuid_;
    s = checkpoint_owner->CreateCheckpoint(checkpoint_shard_path);
    CHECK(s.ok()) << "ERROR: Checkpoint creation for tbe_uuid " << tbe_uuid
                  << ", db shard " << shard << " failed, " << s.code() << ", "
                  << s.ToString();
    shard_checkpoints_.push_back(checkpoint_shard_path);
  }
}

EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
const SnapshotHandle* handle,
std::shared_ptr<EmbeddingRocksDB> db)
Expand All @@ -302,14 +333,24 @@ EmbeddingSnapshotHandleWrapper::~EmbeddingSnapshotHandleWrapper() {
db->release_snapshot(handle);
}

// Bind a checkpoint uuid to the EmbeddingRocksDB that owns it; the handle
// shares ownership of the db so the checkpoint can be released safely later.
RocksdbCheckpointHandleWrapper::RocksdbCheckpointHandleWrapper(
    const std::string& checkpoint_uuid,
    std::shared_ptr<EmbeddingRocksDB> db)
    : uuid(checkpoint_uuid), db(std::move(db)) {}

// Release the checkpoint this handle was keeping alive. Guard against a
// null db so destruction never dereferences nullptr (the wrapper's ctor
// accepts any shared_ptr, including an empty one).
RocksdbCheckpointHandleWrapper::~RocksdbCheckpointHandleWrapper() {
  if (db) {
    db->release_checkpoint(uuid);
  }
}

KVTensorWrapper::KVTensorWrapper(
std::vector<int64_t> shape,
int64_t dtype,
int64_t row_offset,
const std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle,
std::optional<at::Tensor> sorted_indices,
int64_t width_offset_)
int64_t width_offset_,
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> checkpoint_handle)
: db_(nullptr),
shape_(std::move(shape)),
row_offset_(row_offset),
Expand All @@ -333,6 +374,7 @@ KVTensorWrapper::KVTensorWrapper(
if (sorted_indices.has_value()) {
sorted_indices_ = sorted_indices;
}
checkpoint_handle_ = checkpoint_handle;
}

void KVTensorWrapper::set_embedding_rocks_dp_wrapper(
Expand Down Expand Up @@ -473,6 +515,11 @@ static auto embedding_snapshot_handle_wrapper =
"fbgemm",
"EmbeddingSnapshotHandleWrapper");

// Register the checkpoint handle as a TorchScript custom class so Python can
// hold the object returned by EmbeddingRocksDBWrapper::get_active_checkpoint_uuid.
static auto rocksdb_checkpoint_handle_wrapper =
    torch::class_<RocksdbCheckpointHandleWrapper>(
        "fbgemm",
        "RocksdbCheckpointHandleWrapper");

static auto embedding_rocks_db_wrapper =
torch::class_<EmbeddingRocksDBWrapper>("fbgemm", "EmbeddingRocksDBWrapper")
.def(
Expand Down Expand Up @@ -584,7 +631,13 @@ static auto embedding_rocks_db_wrapper =
.def("get_snapshot_count", &EmbeddingRocksDBWrapper::get_snapshot_count)
.def(
"get_keys_in_range_by_snapshot",
&EmbeddingRocksDBWrapper::get_keys_in_range_by_snapshot);
&EmbeddingRocksDBWrapper::get_keys_in_range_by_snapshot)
.def(
"create_rocksdb_hard_link_snapshot",
&EmbeddingRocksDBWrapper::create_rocksdb_hard_link_snapshot)
.def(
"get_active_checkpoint_uuid",
&EmbeddingRocksDBWrapper::get_active_checkpoint_uuid);

static auto dram_kv_embedding_cache_wrapper =
torch::class_<DramKVEmbeddingCacheWrapper>(
Expand Down Expand Up @@ -667,7 +720,8 @@ static auto kv_tensor_wrapper =
std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>,
std::optional<at::Tensor>,
int64_t>(),
int64_t,
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>(),
"",
{torch::arg("shape"),
torch::arg("dtype"),
Expand All @@ -676,7 +730,9 @@ static auto kv_tensor_wrapper =
// not needed for writing
torch::arg("snapshot_handle") = std::nullopt,
torch::arg("sorted_indices") = std::nullopt,
torch::arg("width_offset") = 0})
torch::arg("width_offset") = 0,
torch::arg("checkpoint_handle") =
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>(nullptr)})
.def(
"set_embedding_rocks_dp_wrapper",
&KVTensorWrapper::set_embedding_rocks_dp_wrapper,
Expand Down
Loading
Loading