From c3a8b99e9169174d8f080132c8f90e12ef788977 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 27 Jun 2025 10:44:57 +0300
Subject: [PATCH] recurrent : call balloc split_reset() in init_batch()

ggml-ci
---
 src/llama-memory-recurrent.cpp | 37 +++++++++++++++++++---------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 1b1e95d567a6c..e52156bf308b6 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (!prepare(ubatches)) {
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
 
-    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_full() {