Pass the updated embeddings to EmbeddingKVDB #4210

Closed

@@ -52,6 +52,18 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void stream_cuda(
const Tensor& indices,
const Tensor& weights,
const Tensor& count,
bool blocking_tensor_copy = true) {
return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
}

void stream_sync_cuda() {
return impl_->stream_sync_cuda();
}

void get_cuda(Tensor indices, Tensor weights, Tensor count) {
return impl_->get_cuda(indices, weights, count);
}
@@ -95,6 +107,10 @@ static auto embedding_parameter_server_wrapper =
int64_t,
int64_t>())
.def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
.def("stream_cuda", &EmbeddingParameterServerWrapper::stream_cuda)
.def(
"stream_sync_cuda",
&EmbeddingParameterServerWrapper::stream_sync_cuda)
.def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
.def("compact", &EmbeddingParameterServerWrapper::compact)
.def("flush", &EmbeddingParameterServerWrapper::flush)

@@ -85,6 +85,18 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void stream_cuda(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy = true) {
return impl_->stream_cuda(indices, weights, count, blocking_tensor_copy);
}

void stream_sync_cuda() {
return impl_->stream_sync_cuda();
}

void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
return impl_->get_cuda(indices, weights, count);
}

@@ -205,7 +205,8 @@ EmbeddingKVDB::~EmbeddingKVDB() {
}
#ifdef FBGEMM_FBCODE
if (enable_raw_embedding_streaming_) {
- weights_stream_thread_->join();
+ join_stream_tensor_copy_thread();
+ join_weights_stream_thread();
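+ // Join the copy thread first (it may still be enqueueing), then stop and
+ // join the consumer thread.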
}
#endif
}
@@ -302,6 +303,39 @@ folly::coro::Task<void> EmbeddingKVDB::tensor_stream(
}
co_return;
}

void EmbeddingKVDB::copy_and_enqueue_stream_tensors(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) {
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::copy_and_enqueue_stream_tensors ##");
auto stream_item =
tensor_copy(indices, weights, count, kv_db::RocksdbWriteMode::STREAM);
weights_to_stream_queue_.enqueue(stream_item);
rec->record.end();
}

void EmbeddingKVDB::join_stream_tensor_copy_thread() {
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::join_stream_tensor_copy_thread ##");
if (stream_tensor_copy_thread_ != nullptr &&
stream_tensor_copy_thread_->joinable()) {
stream_tensor_copy_thread_->join();
}
rec->record.end();
}

void EmbeddingKVDB::join_weights_stream_thread() {
if (weights_stream_thread_ != nullptr && weights_stream_thread_->joinable()) {
stop_ = true;
weights_stream_thread_->join();
}
}

uint64_t EmbeddingKVDB::get_weights_to_stream_queue_size() {
return weights_to_stream_queue_.size();
}
#endif

void EmbeddingKVDB::update_cache_and_storage(
@@ -403,6 +437,45 @@ void EmbeddingKVDB::set_cuda(
rec->record.end();
}

void EmbeddingKVDB::stream_cuda(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy) {
#ifdef FBGEMM_FBCODE
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::stream_cuda ##");
check_tensor_type_consistency(indices, weights);
// Take a reference to self to avoid lifetime issues.
auto self = shared_from_this();
std::function<void()>* functor = new std::function<void()>(
[=]() { self->stream(indices, weights, count, blocking_tensor_copy); });
AT_CUDA_CHECK(cudaStreamAddCallback(
at::cuda::getCurrentCUDAStream(),
kv_db_utils::cuda_callback_func,
functor,
0));
rec->record.end();
#endif
}
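
The functor is heap-allocated because cudaStreamAddCallback runs it later, on a
CUDA runtime thread. A minimal sketch of what a trampoline like
kv_db_utils::cuda_callback_func typically looks like (the real implementation
lives elsewhere in FBGEMM and may differ):

    #include <cuda_runtime.h>

    #include <functional>
    #include <memory>

    void CUDART_CB cuda_callback_func(
        cudaStream_t /*stream*/,
        cudaError_t status,
        void* data) {
      // Reclaim ownership of the heap-allocated std::function so it is freed
      // exactly once, then invoke it if the preceding stream work succeeded.
      std::unique_ptr<std::function<void()>> functor(
          static_cast<std::function<void()>*>(data));
      if (status == cudaSuccess) {
        (*functor)();
      }
    }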

void EmbeddingKVDB::stream_sync_cuda() {
#ifdef FBGEMM_FBCODE
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::stream_sync_cuda ##");
// Take a reference to self to avoid lifetime issues.
auto self = shared_from_this();
std::function<void()>* functor = new std::function<void()>(
[=]() { self->join_stream_tensor_copy_thread(); });
AT_CUDA_CHECK(cudaStreamAddCallback(
at::cuda::getCurrentCUDAStream(),
kv_db_utils::cuda_callback_func,
functor,
0));
rec->record.end();
#endif
}

std::vector<double> EmbeddingKVDB::get_l2cache_perf(
const int64_t step,
const int64_t interval) {
@@ -472,6 +545,9 @@ void EmbeddingKVDB::set(
return;
}
CHECK_EQ(max_D_, weights.size(1));

auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::set_callback ##");
// Defer the L2 cache/rocksdb update to the background thread, as it can be
// parallelized with other CUDA kernels, provided all updates finish before
// the next L2 cache lookup.
@@ -487,6 +563,7 @@
} else {
update_cache_and_storage(indices, weights, count, write_mode);
}
rec->record.end();
}

void EmbeddingKVDB::get(
@@ -500,6 +577,8 @@
<< num_lookups;
return;
}
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::get_callback ##");
CHECK_GE(max_D_, weights.size(1));
auto start_ts = facebook::WallClockUtil::NowInUsecFast();
wait_util_filling_work_done();
@@ -560,6 +639,33 @@ void EmbeddingKVDB::get(
get_kv_db_async(indices, weights, count).wait();
}
get_total_duration_ += facebook::WallClockUtil::NowInUsecFast() - start_ts;
rec->record.end();
}

void EmbeddingKVDB::stream(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy) {
if (!enable_raw_embedding_streaming_) {
return;
}
auto rec = torch::autograd::profiler::record_function_enter_new(
"## EmbeddingKVDB::stream_callback ##");
if (blocking_tensor_copy) {
copy_and_enqueue_stream_tensors(indices, weights, count);
return;
}
// Make sure the previous thread is done before starting a new one
join_stream_tensor_copy_thread();
// CUDA dispatches all host callbacks on the same CPU thread, but the
// callbacks don't need to be serialized. So we spin up a new thread to
// unblock the CUDA stream, letting it continue executing other host
// callbacks, e.g. get/evict.
stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
copy_and_enqueue_stream_tensors(indices, weights, count);
});
rec->record.end();
}
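
The serialization the comment above refers to is easy to reproduce with a toy
stream: two host callbacks enqueued on the same stream run strictly one after
another, so a slow copy inside a callback would stall everything behind it.
An illustrative standalone program (not from this PR):

    #include <cuda_runtime.h>

    #include <chrono>
    #include <thread>

    void CUDART_CB slow_cb(cudaStream_t, cudaError_t, void*) {
      // Simulates a long blocking tensor copy inside a host callback.
      std::this_thread::sleep_for(std::chrono::seconds(1));
    }

    void CUDART_CB next_cb(cudaStream_t, cudaError_t, void*) {
      // Runs only after slow_cb returns, even though it is independent work;
      // this is why stream() hands the copy off to its own thread.
    }

    int main() {
      cudaStream_t s;
      cudaStreamCreate(&s);
      cudaStreamAddCallback(s, slow_cb, nullptr, 0);
      cudaStreamAddCallback(s, next_cb, nullptr, 0); // waits for slow_cb
      cudaStreamSynchronize(s);
      cudaStreamDestroy(s);
      return 0;
    }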

std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(

@@ -80,15 +80,19 @@ class CacheContext {
/// BWD_L1_CNFLCT_MISS_WRITE_BACK: L1 conflict miss will insert into L2 for
/// embedding update on bwd path
///
/// All the L2 cache filling above will potentially trigger rocksdb write once
/// L2 cache is full
///
/// STREAM: placeholder for raw embedding streaming requests; it doesn't
/// directly interact with L2 or RocksDB
///
/// Additionally we will do ssd io on L2 flush
enum RocksdbWriteMode {
FWD_ROCKSDB_READ = 0,
FWD_L1_EVICTION = 1,
BWD_L1_CNFLCT_MISS_WRITE_BACK = 2,
FLUSH = 3,
STREAM = 4,
};

/// @ingroup embedding-ssd
@@ -196,6 +200,33 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const at::Tensor& count,
int64_t sleep_ms = 0);

/// Stream out the non-negative elements in <indices> and their paired
/// embeddings from <weights>, for the first <count> elements in the tensor.
/// It spins up a thread that copies all three tensors to the CPU and
/// enqueues them on the background queue, which is drained by a separate
/// set of thread pools that stream the embeddings out to the thrift server
/// (currently co-located on the same host).
///
/// This is called from a CUDA stream callback, which doesn't need to be
/// serialized with other callbacks, so a separate thread is used to
/// maximize the overlap with other callbacks.
///
/// @param indices The 1D embedding index tensor; negative values are
/// skipped
/// @param weights The 2D tensor whose rows (embeddings) are paired with the
/// corresponding elements in <indices>
/// @param count A single-element tensor containing the number of indices to
/// be processed
/// @param blocking_tensor_copy Whether to copy the tensors to be streamed in
/// a blocking manner
///
/// @return None
void stream(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy = true);
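
A hypothetical caller-side sequence for the non-blocking path (the wrapper
construction and tensor setup are assumed, not shown in this PR):

    // db is a std::shared_ptr<kv_db::EmbeddingKVDB>; the tensors were
    // produced on the current CUDA stream.
    db->stream_cuda(indices, weights, count, /*blocking_tensor_copy=*/false);
    // ... more work can be enqueued here; the tensor copy overlaps with it ...
    db->stream_sync_cuda(); // joins stream_tensor_copy_thread_ in stream order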

/// storage tier counterpart of function get()
virtual folly::SemiFuture<std::vector<folly::Unit>> get_kv_db_async(
const at::Tensor& indices,
@@ -234,6 +265,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const int64_t timestep,
const bool is_bwd = false);

void stream_cuda(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool blocking_tensor_copy = true);

void stream_sync_cuda();

/// export internally collected L2 performance metrics out
///
/// @param step the training step that caller side wants to report the stats
@@ -314,6 +353,28 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
folly::coro::Task<void> tensor_stream(
const at::Tensor& indices,
const at::Tensor& weights);
/*
* Copy the indices, weights and count tensors and enqueue them for
* asynchronous streaming.
*/
void copy_and_enqueue_stream_tensors(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count);

/*
* Join the stream tensor copy thread to make sure it has properly finished
* before a new one is created.
*/
void join_stream_tensor_copy_thread();

/*
* FOR TESTING: Join the weights stream thread to make sure it has properly
* finished, before destruction and in tests.
*/
void join_weights_stream_thread();
// FOR TESTING: get queue size.
uint64_t get_weights_to_stream_queue_size();
#endif

private:
@@ -455,7 +516,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
std::vector<int64_t> table_offsets_;
at::Tensor table_sizes_;
std::unique_ptr<std::thread> weights_stream_thread_;
- folly::USPSCQueue<QueueItem, true> weights_to_stream_queue_;
+ folly::UMPSCQueue<QueueItem, true> weights_to_stream_queue_;
std::unique_ptr<std::thread> stream_tensor_copy_thread_;
}; // class EmbeddingKVDB

} // namespace kv_db
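
Two details in this header are easy to miss: the queue changes from
USPSCQueue to UMPSCQueue, presumably because items are now enqueued from more
than one thread (the CUDA callback thread on the blocking path,
stream_tensor_copy_thread_ on the non-blocking one), and
join_weights_stream_thread() relies on the consumer re-checking stop_ between
dequeues. A standalone sketch of that producer/consumer shape, with
illustrative names and a simplified busy-wait loop (not FBGEMM code):

    #include <folly/concurrency/UnboundedQueue.h>

    #include <atomic>
    #include <thread>

    struct QueueItem {}; // stand-in for the real queue item type

    int main() {
      std::atomic<bool> stop{false};
      // Multi-producer/single-consumer, MayBlock=true, as declared above.
      folly::UMPSCQueue<QueueItem, /*MayBlock=*/true> queue;

      // Consumer: drains the queue until asked to stop, like
      // weights_stream_thread_ (reduced to a busy-wait here).
      std::thread consumer([&] {
        while (!stop) {
          QueueItem item;
          if (queue.try_dequeue(item)) {
            // stream the item out to the receiving service ...
          }
        }
      });

      // Two producers, like the blocking and non-blocking stream paths.
      std::thread p1([&] { queue.enqueue(QueueItem{}); });
      std::thread p2([&] { queue.enqueue(QueueItem{}); });

      p1.join();
      p2.join();
      stop = true; // as join_weights_stream_thread() does via stop_
      consumer.join();
      return 0;
    }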

@@ -549,6 +549,17 @@ static auto embedding_rocks_db_wrapper =
torch::arg("timestep"),
torch::arg("is_bwd") = false,
})
.def(
"stream_cuda",
&EmbeddingRocksDBWrapper::stream_cuda,
"",
{
torch::arg("indices"),
torch::arg("weights"),
torch::arg("count"),
torch::arg("blocking_tensor_copy"),
})
.def("stream_sync_cuda", &EmbeddingRocksDBWrapper::stream_sync_cuda)
.def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)
.def("compact", &EmbeddingRocksDBWrapper::compact)
.def("flush", &EmbeddingRocksDBWrapper::flush)

@@ -485,6 +485,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
case kv_db::RocksdbWriteMode::FLUSH:
flush_write_dur_ += duration;
break;
case kv_db::RocksdbWriteMode::STREAM:
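// Raw embedding streaming doesn't touch L2/RocksDB, so there is no
// write duration to record.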
break;
}
#endif
return folly::collect(futures);