@@ -65,6 +65,25 @@ class SnapshotHandle {
65
65
std::vector<snapshot_ptr_t > shard_snapshots_;
66
66
}; // class SnapshotHandle
67
67
68
+ using checkpoint_path = std::string;
69
+ // @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
70
+ class CheckpointHandle {
71
+ public:
72
+ explicit CheckpointHandle (
73
+ EmbeddingRocksDB* db,
74
+ const std::string& tbe_uuid,
75
+ const std::string& ckpt_uuid,
76
+ const std::string& base_path,
77
+ bool use_default_ssd_path);
78
+
79
+ private:
80
+ friend class EmbeddingRocksDB ;
81
+
82
+ EmbeddingRocksDB* db_;
83
+ std::string ckpt_uuid_;
84
+ std::vector<checkpoint_path> shard_checkpoints_;
85
+ }; // class CheckpointHandle
86
+
68
87
// / @ingroup embedding-ssd
69
88
// /
70
89
// / @brief An implementation of EmbeddingKVDB for RocksDB
@@ -488,6 +507,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
488
507
return snapshots_.find (snapshot_handle) != snapshots_.end ();
489
508
}
490
509
510
+ bool is_valid_checkpoint (const std::string ckpt_uuid) const {
511
+ return checkpoints_.find (ckpt_uuid) != checkpoints_.end ();
512
+ }
513
+
491
514
int64_t get_snapshot_count () const {
492
515
return snapshots_.size ();
493
516
}
@@ -505,6 +528,58 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
505
528
return handlePtr;
506
529
}
507
530
531
+ std::string create_checkpoint (int64_t global_step) {
532
+ const auto num_ckpts = checkpoints_.size ();
533
+ if (num_ckpts > 0 ) {
534
+ std::cerr << " WARNING: rocksdb create_checkpoint found " << num_ckpts
535
+ << " other checkpoints_" << std::endl;
536
+ }
537
+
538
+ // Removing the CHECK needs extra caution, basically we need to find a way
539
+ // to consistently link a KVTensor to a checkpoint. Ideally by uuid, but
540
+ // create_checkpoint is called earlier than state_dict and KVTensor is
541
+ // created inside state_dict, additionally we don't want to pass the uuid
542
+ // into state_dict, unless we link it manually on the caller side inside the
543
+ // state_dict loop, it would be random linking. For example, Case1 and Case2
544
+ // could call state_dict in the same batch at any order with random times,
545
+ // it is impossible to correctly attach each checkpoint to each KVTensor.
546
+ CHECK (
547
+ global_step_to_ckpt_uuid_.find (global_step) ==
548
+ global_step_to_ckpt_uuid_.end ())
549
+ << " multiple rdb checkpoint in one global step isn't supported right now" ;
550
+ auto ckpt_uuid = facebook::strings::generateUUID ();
551
+ auto handle = std::make_unique<CheckpointHandle>(
552
+ this , tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
553
+ checkpoints_[ckpt_uuid] = std::move (handle);
554
+ global_step_to_ckpt_uuid_[global_step] = ckpt_uuid;
555
+ return ckpt_uuid;
556
+ }
557
+
558
+ std::optional<std::string> get_active_checkpoint_uuid (int64_t global_step) {
559
+ if (global_step_to_ckpt_uuid_.find (global_step) !=
560
+ global_step_to_ckpt_uuid_.end ()) {
561
+ return std::make_optional<std::string>(
562
+ global_step_to_ckpt_uuid_[global_step]);
563
+ }
564
+ return std::nullopt;
565
+ }
566
+
567
+ void release_checkpoint (const std::string ckpt_uuid) {
568
+ CHECK (is_valid_checkpoint (ckpt_uuid));
569
+ LOG (INFO) << " Checkpoint " << ckpt_uuid << " released" ;
570
+ checkpoints_.erase (ckpt_uuid);
571
+ // sweep through global_step_to_ckpt_uuid_, it should be small
572
+ int64_t glb_step_to_purge = -1 ;
573
+ for (const auto & [global_step, uuid] : global_step_to_ckpt_uuid_) {
574
+ if (ckpt_uuid == uuid) {
575
+ glb_step_to_purge = global_step;
576
+ break ;
577
+ }
578
+ }
579
+ CHECK (glb_step_to_purge != -1 ) << " There must be a rdb ckpt uuid to purge" ;
580
+ global_step_to_ckpt_uuid_.erase (glb_step_to_purge);
581
+ }
582
+
508
583
void release_snapshot (const SnapshotHandle* snapshot_handle) {
509
584
CHECK (is_valid_snapshot (snapshot_handle));
510
585
LOG (INFO) << " Snapshot " << snapshot_handle << " released" ;
@@ -1117,6 +1192,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
1117
1192
}
1118
1193
1119
1194
friend class SnapshotHandle ;
1195
+ friend class CheckpointHandle ;
1120
1196
1121
1197
std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
1122
1198
std::vector<std::unique_ptr<Initializer>> initializers_;
@@ -1149,6 +1225,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
1149
1225
std::string tbe_uuid_;
1150
1226
std::string path_;
1151
1227
bool use_default_ssd_path_;
1228
+
1229
+ // rocksdb checkpoint is used to create an on disk database to support
1230
+ // cross process read-only access, currently this is only needed for publish
1231
+ // in async checkpoint, SSD read is handled in the same proc
1232
+ std::unordered_map<std::string, std::unique_ptr<CheckpointHandle>>
1233
+ checkpoints_;
1234
+ // this is used to link KVTensor to the corresponding rdb checkpoint by global
1235
+ // step reasons are shown below
1236
+ // 1. rdb checkpoint is only needed during publish not checkpoint, therefore
1237
+ // only some state_dict() calls need it
1238
+ // 2. we dont' want to add new argument in state_dict to add complexity
1239
+ // 3. therefore create_checkpoint is exposed up to the publish component
1240
+ // 4. publish will call create_checkpoint() then state_dict()
1241
+ // 5. need a way to link rdb ckpt and KVTensor(from state_dict()) together so
1242
+ // that caller doesn't need to do the link themselves
1243
+ std::unordered_map<int64_t , std::string> global_step_to_ckpt_uuid_;
1152
1244
}; // class EmbeddingRocksDB
1153
1245
1154
1246
} // namespace ssd
0 commit comments