Creating CheckpointHandle and connecting to EmbeddingRocksDB #4222

Closed · wants to merge 2 commits
@@ -10,15 +10,20 @@

#include <ATen/ATen.h>
#include <folly/hash/Hash.h>
#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>
#include <filesystem>
#include <optional>

/// @defgroup embedding-ssd Embedding SSD Operators
///

namespace kv_db_utils {

#ifdef FBGEMM_FBCODE
constexpr size_t num_ssd_drives = 8;
#endif

/// @ingroup embedding-ssd
///
/// @brief hash function used for SSD L2 cache and rocksdb sharding algorithm
@@ -65,4 +70,94 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
std::optional<int64_t> bucket_size,
std::optional<int64_t> total_num_buckets);

/// @ingroup embedding-ssd
///
/// @brief default way to generate the rocksdb path based on a user-provided
/// base_path; the file hierarchy will be
/// <base_path><ssd_idx>/<tbe_uuid> for the default SSD mount
/// <base_path>/<tbe_uuid> for a user-provided base path
///
/// @param base_path the base path for all the rocksdb shards tied to one
/// TBE/EmbeddingRocksDB
/// @param db_shard_id the rocksdb shard index, used to determine which
/// SSD to use
/// @param tbe_uuid unique identifier per TBE over the lifetime of a training
/// job
/// @param default_path whether the base_path is the default SSD mount or
/// user-provided
///
/// @return the base path to that rocksdb shard
inline std::string get_rocksdb_path(
const std::string& base_path,
int db_shard_id,
const std::string& tbe_uuid,
bool default_path) {
if (default_path) {
int ssd_drive_idx = db_shard_id % num_ssd_drives;
std::string ssd_idx_tbe_id_str =
std::to_string(ssd_drive_idx) + std::string("/") + tbe_uuid;
return base_path + ssd_idx_tbe_id_str;
} else {
return base_path + std::string("/") + tbe_uuid;
}
}
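For illustration, a minimal sketch of the two layouts this helper produces (the uuid and user path below are hypothetical; the default mount matches the "/data00_nvidia" constant used by EmbeddingRocksDB):

// default SSD mount: shard 10 lands on drive 10 % num_ssd_drives == 2
// get_rocksdb_path("/data00_nvidia", 10, "abc-123", /*default_path=*/true)
//   returns "/data00_nvidia2/abc-123"
// user-provided base path:
// get_rocksdb_path("/home/user/db", 10, "abc-123", /*default_path=*/false)
//   returns "/home/user/db/abc-123"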

/// @ingroup embedding-ssd
///
/// @brief generate the rocksdb shard path based on rocksdb_path;
/// the file hierarchy will be
/// <rocksdb_path>/shard_<db_shard_id>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for the rocksdb shard
///
/// @return the rocksdb shard path
inline std::string get_rocksdb_shard_path(
int db_shard_id,
const std::string& rocksdb_path) {
return rocksdb_path + std::string("/shard_") + std::to_string(db_shard_id);
}

/// @ingroup embedding-ssd
///
/// @brief generate a directory to hold the rocksdb checkpoints for a
/// particular rocksdb shard; the file hierarchy will be
/// <rocksdb_path>/checkpoint_shard_<db_shard_id>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for the rocksdb shard
///
/// @return the directory that holds rocksdb checkpoints for one rocksdb shard
inline std::string get_rocksdb_checkpoint_dir(
int db_shard_id,
const std::string& rocksdb_path) {
return rocksdb_path + std::string("/checkpoint_shard_") +
std::to_string(db_shard_id);
}
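Putting the three helpers together, the on-disk hierarchy for shard 2 under a hypothetical user-provided base path looks like:

// get_rocksdb_path("/home/user/db", 2, "abc-123", false) -> "/home/user/db/abc-123"
// get_rocksdb_shard_path(2, ...)          -> "/home/user/db/abc-123/shard_2"
// get_rocksdb_checkpoint_dir(2, ...)      -> "/home/user/db/abc-123/checkpoint_shard_2"
// each materialized checkpoint then lives under
// "/home/user/db/abc-123/checkpoint_shard_2/<ckpt_uuid>" (see CheckpointHandle)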

inline void create_dir(const std::string& dir_path) {
try {
std::filesystem::path fs_path(dir_path);
bool res = std::filesystem::create_directories(fs_path);
if (!res) {
LOG(ERROR) << "dir: " << dir_path << " already exists";
}
} catch (const std::exception& e) {
LOG(ERROR) << "Error creating directory: " << e.what();
}
}

inline void remove_dir(const std::string& path) {
if (std::filesystem::exists(path)) {
try {
if (std::filesystem::is_directory(path)) {
std::filesystem::remove_all(path);
} else {
std::filesystem::remove(path);
}
} catch (const std::filesystem::filesystem_error& e) {
LOG(ERROR) << "Error removing path: " << path
<< ", exception:" << e.what();
}
}
}
} // namespace kv_db_utils
@@ -18,7 +18,7 @@
#include "embedding_rocksdb_wrapper.h"
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "rocksdb/utilities/checkpoint.h"
using namespace at;
using namespace ssd;
using namespace kv_mem;
@@ -293,6 +293,37 @@ snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
return shard_snapshots_[shard];
}

CheckpointHandle::CheckpointHandle(
EmbeddingRocksDB* db,
const std::string& tbe_uuid,
const std::string& ckpt_uuid,
const std::string& base_path,
bool use_default_ssd_path)
: db_(db), ckpt_uuid_(ckpt_uuid) {
auto num_shards = db->num_shards();
CHECK_GT(num_shards, 0);
shard_checkpoints_.reserve(num_shards);
for (auto shard = 0; shard < num_shards; ++shard) {
auto rocksdb_path = kv_db_utils::get_rocksdb_path(
base_path, shard, tbe_uuid, use_default_ssd_path);
auto checkpoint_shard_dir =
kv_db_utils::get_rocksdb_checkpoint_dir(shard, rocksdb_path);
kv_db_utils::create_dir(checkpoint_shard_dir);
rocksdb::Checkpoint* checkpoint = nullptr;
rocksdb::Status s =
rocksdb::Checkpoint::Create(db->dbs_[shard].get(), &checkpoint);
CHECK(s.ok()) << "ERROR: Checkpoint init for tbe_uuid " << tbe_uuid
<< ", db shard " << shard << " failed, " << s.code() << ", "
<< s.ToString();
std::string checkpoint_shard_path = checkpoint_shard_dir + "/" + ckpt_uuid_;
s = checkpoint->CreateCheckpoint(checkpoint_shard_path);
CHECK(s.ok()) << "ERROR: Checkpoint creation for tbe_uuid " << tbe_uuid
<< ", db shard " << shard << " failed, " << s.code() << ", "
<< s.ToString();
shard_checkpoints_.push_back(checkpoint_shard_path);
}
}
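Because each materialized checkpoint is a self-contained on-disk RocksDB database, a second process can open it read-only without touching the live shard; a minimal sketch (the path is hypothetical and error handling is elided):

rocksdb::DB* ro_db = nullptr;
rocksdb::Options options;
// open the checkpoint produced above; the live DB keeps serving traffic
rocksdb::Status s = rocksdb::DB::OpenForReadOnly(
    options, "/home/user/db/abc-123/checkpoint_shard_0/<ckpt_uuid>", &ro_db);
CHECK(s.ok()) << s.ToString();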

EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
const SnapshotHandle* handle,
std::shared_ptr<EmbeddingRocksDB> db)
134 changes: 117 additions & 17 deletions fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
@@ -33,11 +33,7 @@ namespace ssd {

using namespace at;

#ifdef FBGEMM_FBCODE
const std::string ssd_mount_point = "/data00_nvidia";
const size_t base_port = 136000;
#endif

// mem usage properties
// -- block cache usage
@@ -69,6 +65,25 @@ class SnapshotHandle {
std::vector<snapshot_ptr_t> shard_snapshots_;
}; // class SnapshotHandle

using checkpoint_path = std::string;
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
class CheckpointHandle {
public:
explicit CheckpointHandle(
EmbeddingRocksDB* db,
const std::string& tbe_uuid,
const std::string& ckpt_uuid,
const std::string& base_path,
bool use_default_ssd_path);

private:
friend class EmbeddingRocksDB;

EmbeddingRocksDB* db_;
std::string ckpt_uuid_;
std::vector<checkpoint_path> shard_checkpoints_;
}; // class CheckpointHandle

/// @ingroup embedding-ssd
///
/// @brief An implementation of EmbeddingKVDB for RocksDB
@@ -322,23 +337,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
auto db_monitor_options = facebook::fb_rocksdb::DBMonitorOptions();
db_monitor_options.fb303Prefix = "tbe_metrics";

tbe_uuid_ = facebook::strings::generateUUID();
use_default_ssd_path_ = !use_passed_in_path;
if (!use_passed_in_path) {
path_ = std::move(ssd_mount_point);
} else {
path_ = std::move(path);
}
std::string all_shards_path;
#endif
for (auto i = 0; i < num_shards; ++i) {
#ifdef FBGEMM_FBCODE
auto rocksdb_path = kv_db_utils::get_rocksdb_path(
path_, i, tbe_uuid_, !use_passed_in_path);
auto shard_path = kv_db_utils::get_rocksdb_shard_path(i, rocksdb_path);
kv_db_utils::create_dir(shard_path);
all_shards_path += shard_path + ", ";
#else
auto shard_path = path + std::string("/shard_") + std::to_string(i);
#endif
@@ -371,7 +385,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
dbs_.emplace_back(db);
}
#ifdef FBGEMM_FBCODE
LOG(INFO) << "TBE actual used_path: " << used_path;
LOG(INFO) << "TBE uuid: " << tbe_uuid_
<< ", rocksdb shards paths: " << all_shards_path;
#endif
}

@@ -494,6 +509,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
return snapshots_.find(snapshot_handle) != snapshots_.end();
}

bool is_valid_checkpoint(const std::string& ckpt_uuid) const {
return checkpoints_.find(ckpt_uuid) != checkpoints_.end();
}

int64_t get_snapshot_count() const {
return snapshots_.size();
}
@@ -511,6 +530,57 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
return handlePtr;
}

std::string create_checkpoint(int64_t global_step) {
const auto num_ckpts = checkpoints_.size();
if (num_ckpts > 0) {
LOG(WARNING) << "rocksdb create_checkpoint found " << num_ckpts
<< " existing checkpoints_";
}

// If the global step already has a checkpoint handle registered by the
// time create_checkpoint is called, we assume the previous handle has
// fulfilled its job, so it is safe to replace it with the new rocksdb
// checkpoint for subsequent use cases within the same global step
if (global_step_to_ckpt_uuid_.find(global_step) !=
global_step_to_ckpt_uuid_.end()) {
LOG(WARNING)
<< "multiple rocksdb checkpoints in one global step are being "
"created; removing the previous checkpoint, please make sure it "
"has fulfilled its use case, e.g. checkpoint and publish";
}
auto ckpt_uuid = facebook::strings::generateUUID();
auto handle = std::make_unique<CheckpointHandle>(
this, tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
checkpoints_[ckpt_uuid] = std::move(handle);
global_step_to_ckpt_uuid_[global_step] = ckpt_uuid;
return ckpt_uuid;
}

std::optional<std::string> get_active_checkpoint_uuid(int64_t global_step) {
auto it = global_step_to_ckpt_uuid_.find(global_step);
if (it != global_step_to_ckpt_uuid_.end()) {
return std::make_optional<std::string>(it->second);
}
return std::nullopt;
}

void release_checkpoint(const std::string& ckpt_uuid) {
CHECK(is_valid_checkpoint(ckpt_uuid));
LOG(INFO) << "Checkpoint " << ckpt_uuid << " released";
checkpoints_.erase(ckpt_uuid);
// sweep through global_step_to_ckpt_uuid_, it should be small
int64_t glb_step_to_purge = -1;
for (const auto& [global_step, uuid] : global_step_to_ckpt_uuid_) {
if (ckpt_uuid == uuid) {
glb_step_to_purge = global_step;
break;
}
}
CHECK_NE(glb_step_to_purge, -1) << "There must be a rdb ckpt uuid to purge";
global_step_to_ckpt_uuid_.erase(glb_step_to_purge);
}
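A sketch of the intended caller-side lifecycle, assuming db points to an EmbeddingRocksDB (the global step value is arbitrary):

auto ckpt_uuid = db->create_checkpoint(/*global_step=*/100);
// later, e.g. while serving state_dict() through KVTensorWrapper:
auto active = db->get_active_checkpoint_uuid(100); // == ckpt_uuid
// once checkpoint/publish has consumed the on-disk checkpoint:
db->release_checkpoint(ckpt_uuid);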

void release_snapshot(const SnapshotHandle* snapshot_handle) {
CHECK(is_valid_snapshot(snapshot_handle));
LOG(INFO) << "Snapshot " << snapshot_handle << " released";
@@ -1123,6 +1193,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
}

friend class SnapshotHandle;
friend class CheckpointHandle;

std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
std::vector<std::unique_ptr<Initializer>> initializers_;
@@ -1152,6 +1223,35 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
int64_t elem_size_;
std::vector<int64_t> sub_table_dims_;
std::vector<int64_t> sub_table_hash_cumsum_;
std::string tbe_uuid_;
std::string path_;
bool use_default_ssd_path_;

// rocksdb checkpoint is used to create an on-disk database to support
// cross-process read-only access
std::unordered_map<std::string, std::unique_ptr<CheckpointHandle>>
checkpoints_;
// this maps global step -> rocksdb checkpoint uuid, used for linking a
// KVTensor to a rocksdb checkpoint; the reasons are:
// 1. a rocksdb checkpoint is created at most twice, for publish and
// checkpoint separately, if they happen on the same train iteration. We
// cannot create rocksdb checkpoints freely because the lifecycle of a
// checkpoint is controlled on the component side
//
// 2. publish tends to call state_dict() multiple times to get model FQNs,
// and it is not recommended to modify the state_dict signature, so there is
// no way for the TBE backend to tell which state_dict call is for weight
// access. state_dict() returns KVTensorWrapper to the trainer side, which
// is consumed by downstream components, e.g. checkpoint and publish; we
// want to link the rocksdb checkpoint with KVTensorWrapper
//
// 3. therefore we need a way to link the created rocksdb checkpoint with
// KVTensorWrapper, and we could potentially have multiple rocksdb
// checkpoints from different iterations (this is unlikely, especially if we
// don't copy KVTensorWrapper to a separate python thread that extends the
// checkpoint handle's lifetime). Just in case, we keep a global step ->
// rocksdb checkpoint mapping
std::unordered_map<int64_t, std::string> global_step_to_ckpt_uuid_;
}; // class EmbeddingRocksDB

} // namespace ssd
@@ -38,7 +38,8 @@ TEST(RocksDbEmbeddingCacheTest, TestPutAndGet) {
-0.01, // uniform_init_lower,
0.01, // uniform_init_upper,
32, // row_storage_bitwidth = 32,
0, // cache_size = 0
true // use_passed_in_path
);

auto write_indices =