Skip to content

Commit f812b4e

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
Raahul Kalyaan Jakka
authored andcommitted
Creating CheckpointHandle and connecting it to EmbeddingRocksDB
Summary: X-link: facebookresearch/FBGEMM#1300 Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54 Context: We are enabling the use of the RocksDB checkpoint feature in KVTensorWrapper. This allows us to create checkpoints of the embedding tables on SSD. Later, these checkpoints are used by the checkpointing component to create a checkpoint and upload it to Manifold. In this diff: CheckpointHandle: an entity responsible for storing the details of the RocksDB checkpoint. It consists of the file paths to the checkpoints of all shards. When creating KVTensorWrappers, we use the same CheckpointHandle object. 1. Creating the CheckpointHandle class 2. Adding the definition of the CheckpointHandle constructor 3. Adding CRUD functions for CheckpointHandle 4. Connecting CheckpointHandle to the EmbeddingRocksDB entity Differential Revision: D75489827
1 parent d598938 commit f812b4e

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "embedding_rocksdb_wrapper.h"
1919
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
2020
#include "fbgemm_gpu/utils/ops_utils.h"
21-
21+
#include "rocksdb/utilities/checkpoint.h"
2222
using namespace at;
2323
using namespace ssd;
2424
using namespace kv_mem;
@@ -293,6 +293,37 @@ snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
293293
return shard_snapshots_[shard];
294294
}
295295

296+
/// Creates an on-disk RocksDB checkpoint for every shard of `db` and records
/// the resulting paths in shard_checkpoints_ (one entry per shard, in shard
/// order).
///
/// For each shard the checkpoint directory is derived with
/// kv_db_utils::get_rocksdb_path() and
/// kv_db_utils::get_rocksdb_checkpoint_dir(), created with
/// kv_db_utils::create_dir(), and the checkpoint itself is written to
/// `<checkpoint_shard_dir>/<ckpt_uuid>`.
///
/// Any RocksDB failure CHECK-fails with the tbe_uuid, shard index and status.
///
/// @param db EmbeddingRocksDB whose shards are checkpointed; stored as a raw
///        pointer in db_
/// @param tbe_uuid forwarded to get_rocksdb_path() and used in error messages
/// @param ckpt_uuid name of the per-shard checkpoint sub-directory
/// @param base_path forwarded to get_rocksdb_path()
/// @param use_default_ssd_path forwarded to get_rocksdb_path()
CheckpointHandle::CheckpointHandle(
    EmbeddingRocksDB* db,
    const std::string& tbe_uuid,
    const std::string& ckpt_uuid,
    const std::string& base_path,
    bool use_default_ssd_path)
    : db_(db), ckpt_uuid_(ckpt_uuid) {
  auto num_shards = db->num_shards();
  CHECK_GT(num_shards, 0);
  shard_checkpoints_.reserve(num_shards);
  for (auto shard = 0; shard < num_shards; ++shard) {
    auto rocksdb_path = kv_db_utils::get_rocksdb_path(
        base_path, shard, tbe_uuid, use_default_ssd_path);
    auto checkpoint_shard_dir =
        kv_db_utils::get_rocksdb_checkpoint_dir(shard, rocksdb_path);
    kv_db_utils::create_dir(checkpoint_shard_dir);

    rocksdb::Checkpoint* raw_checkpoint = nullptr;
    rocksdb::Status s =
        rocksdb::Checkpoint::Create(db->dbs_[shard].get(), &raw_checkpoint);
    // Checkpoint::Create() allocates the Checkpoint object and leaves it to
    // the caller to delete. Take ownership immediately so it is released at
    // the end of this iteration instead of leaking once per shard.
    std::unique_ptr<rocksdb::Checkpoint> checkpoint(raw_checkpoint);
    CHECK(s.ok()) << "ERROR: Checkpoint init for tbe_uuid " << tbe_uuid
                  << ", db shard " << shard << " failed, " << s.code() << ", "
                  << s.ToString();

    std::string checkpoint_shard_path = checkpoint_shard_dir + "/" + ckpt_uuid_;
    s = checkpoint->CreateCheckpoint(checkpoint_shard_path);
    CHECK(s.ok()) << "ERROR: Checkpoint creation for tbe_uuid " << tbe_uuid
                  << ", db shard " << shard << " failed, " << s.code() << ", "
                  << s.ToString();
    shard_checkpoints_.push_back(checkpoint_shard_path);
  }
}
326+
296327
EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
297328
const SnapshotHandle* handle,
298329
std::shared_ptr<EmbeddingRocksDB> db)

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,25 @@ class SnapshotHandle {
6565
std::vector<snapshot_ptr_t> shard_snapshots_;
6666
}; // class SnapshotHandle
6767

68+
// Filesystem path of one shard's RocksDB checkpoint.
using checkpoint_path = std::string;
69+
/// Records one RocksDB checkpoint of an EmbeddingRocksDB: the checkpoint uuid
/// and the on-disk checkpoint path of every shard.
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
70+
class CheckpointHandle {
71+
public:
72+
/// Creates a checkpoint for each shard of `db` under a directory derived from
/// `base_path`, `tbe_uuid` and `use_default_ssd_path`, naming each per-shard
/// checkpoint directory `ckpt_uuid`. Failures CHECK-fail.
explicit CheckpointHandle(
73+
EmbeddingRocksDB* db,
74+
const std::string& tbe_uuid,
75+
const std::string& ckpt_uuid,
76+
const std::string& base_path,
77+
bool use_default_ssd_path);
78+
79+
private:
80+
// EmbeddingRocksDB is granted access to the private members below.
friend class EmbeddingRocksDB;
81+
82+
// Database this checkpoint was taken from; raw pointer, no destructor is
// declared in this class.
EmbeddingRocksDB* db_;
83+
// Identifier of this checkpoint; also the leaf directory name of each shard
// checkpoint.
std::string ckpt_uuid_;
84+
// One checkpoint path per shard, appended in shard order by the constructor.
std::vector<checkpoint_path> shard_checkpoints_;
85+
}; // class CheckpointHandle
86+
6887
/// @ingroup embedding-ssd
6988
///
7089
/// @brief An implementation of EmbeddingKVDB for RocksDB
@@ -488,6 +507,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
488507
return snapshots_.find(snapshot_handle) != snapshots_.end();
489508
}
490509

510+
bool is_valid_checkpoint(const std::string ckpt_uuid) const {
511+
return checkpoints_.find(ckpt_uuid) != checkpoints_.end();
512+
}
513+
491514
int64_t get_snapshot_count() const {
492515
return snapshots_.size();
493516
}
@@ -505,6 +528,58 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
505528
return handlePtr;
506529
}
507530

531+
std::string create_checkpoint(int64_t global_step) {
532+
const auto num_ckpts = checkpoints_.size();
533+
if (num_ckpts > 0) {
534+
std::cerr << "WARNING: rocksdb create_checkpoint found " << num_ckpts
535+
<< " other checkpoints_" << std::endl;
536+
}
537+
538+
// Removing the CHECK needs extra caution, basically we need to find a way
539+
// to consistently link a KVTensor to a checkpoint. Ideally by uuid, but
540+
// create_checkpoint is called earlier than state_dict and KVTensor is
541+
// created inside state_dict, additionally we don't want to pass the uuid
542+
// into state_dict, unless we link it manually on the caller side inside the
543+
// state_dict loop, it would be random linking. For example, Case1 and Case2
544+
// could call state_dict in the same batch at any order with random times,
545+
// it is impossible to correctly attach each checkpoint to each KVTensor.
546+
CHECK(
547+
global_step_to_ckpt_uuid_.find(global_step) ==
548+
global_step_to_ckpt_uuid_.end())
549+
<< "multiple rdb checkpoint in one global step isn't supported right now";
550+
auto ckpt_uuid = facebook::strings::generateUUID();
551+
auto handle = std::make_unique<CheckpointHandle>(
552+
this, tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
553+
checkpoints_[ckpt_uuid] = std::move(handle);
554+
global_step_to_ckpt_uuid_[global_step] = ckpt_uuid;
555+
return ckpt_uuid;
556+
}
557+
558+
std::optional<std::string> get_active_checkpoint_uuid(int64_t global_step) {
559+
if (global_step_to_ckpt_uuid_.find(global_step) !=
560+
global_step_to_ckpt_uuid_.end()) {
561+
return std::make_optional<std::string>(
562+
global_step_to_ckpt_uuid_[global_step]);
563+
}
564+
return std::nullopt;
565+
}
566+
567+
/// Stops tracking the checkpoint identified by `ckpt_uuid`.
///
/// Removes the CheckpointHandle from checkpoints_ and removes the
/// global-step entry that points at this uuid from global_step_to_ckpt_uuid_.
///
/// CHECK-fails if `ckpt_uuid` is not a tracked checkpoint, or if no global
/// step is linked to it.
///
/// @param ckpt_uuid uuid of the checkpoint to release. Deliberately taken by
///        value: it is still read after checkpoints_.erase(), so a private
///        copy stays valid even if the caller's string lives inside the
///        CheckpointHandle that is being destroyed.
void release_checkpoint(const std::string ckpt_uuid) {
  CHECK(is_valid_checkpoint(ckpt_uuid));
  LOG(INFO) << "Checkpoint " << ckpt_uuid << " released";
  checkpoints_.erase(ckpt_uuid);

  // Sweep through global_step_to_ckpt_uuid_ (expected to be small) for the
  // step linked to this uuid. Tracking the match by iterator rather than by a
  // -1 sentinel means a checkpoint created at global_step == -1 is handled
  // correctly, and the entry is erased without a second lookup.
  auto step_it = global_step_to_ckpt_uuid_.begin();
  for (; step_it != global_step_to_ckpt_uuid_.end(); ++step_it) {
    if (step_it->second == ckpt_uuid) {
      break;
    }
  }
  CHECK(step_it != global_step_to_ckpt_uuid_.end())
      << "There must be a rdb ckpt uuid to purge";
  global_step_to_ckpt_uuid_.erase(step_it);
}
582+
508583
void release_snapshot(const SnapshotHandle* snapshot_handle) {
509584
CHECK(is_valid_snapshot(snapshot_handle));
510585
LOG(INFO) << "Snapshot " << snapshot_handle << " released";
@@ -1117,6 +1192,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
11171192
}
11181193

11191194
friend class SnapshotHandle;
1195+
friend class CheckpointHandle;
11201196

11211197
std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
11221198
std::vector<std::unique_ptr<Initializer>> initializers_;
@@ -1149,6 +1225,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
11491225
std::string tbe_uuid_;
11501226
std::string path_;
11511227
bool use_default_ssd_path_;
1228+
1229+
// rocksdb checkpoint is used to create an on disk database to support
1230+
// cross process read-only access, currently this is only needed for publish
1231+
// in async checkpoint, SSD read is handled in the same proc
1232+
std::unordered_map<std::string, std::unique_ptr<CheckpointHandle>>
1233+
checkpoints_;
1234+
// this is used to link KVTensor to the corresponding rdb checkpoint by global
1235+
// step; the reasons are listed below:
1236+
// 1. rdb checkpoint is only needed during publish not checkpoint, therefore
1237+
// only some state_dict() calls need it
1238+
// 2. we don't want to add a new argument to state_dict, which would add complexity
1239+
// 3. therefore create_checkpoint is exposed up to the publish component
1240+
// 4. publish will call create_checkpoint() then state_dict()
1241+
// 5. need a way to link rdb ckpt and KVTensor(from state_dict()) together so
1242+
// that caller doesn't need to do the link themselves
1243+
std::unordered_map<int64_t, std::string> global_step_to_ckpt_uuid_;
11521244
}; // class EmbeddingRocksDB
11531245

11541246
} // namespace ssd

0 commit comments

Comments
 (0)