From 6c1a8951f5b48ccc660626411d882f7aa05aa5bd Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Sun, 25 May 2025 07:16:48 +0800
Subject: [PATCH] tests: add unit tests for llama batch processing
 functionality

---
 tests/CMakeLists.txt       |   1 +
 tests/test-llama-batch.cpp | 565 +++++++++++++++++++++++++++++++++++++
 2 files changed, 566 insertions(+)
 create mode 100644 tests/test-llama-batch.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 083347d188880..067a15abd64e2 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -164,6 +164,7 @@ if (NOT GGML_BACKEND_DL)
     llama_build_and_test(test-quantize-fns.cpp)
     llama_build_and_test(test-quantize-perf.cpp)
     llama_build_and_test(test-rope.cpp)
+    llama_build_and_test(test-llama-batch.cpp)
 endif()

 # libmtmd
diff --git a/tests/test-llama-batch.cpp b/tests/test-llama-batch.cpp
new file mode 100644
index 0000000000000..0ffc181263b5c
--- /dev/null
+++ b/tests/test-llama-batch.cpp
@@ -0,0 +1,565 @@
+#include "../src/llama-batch.h"
+#include "../common/common.h"
+#include "llama.h"
+
+#include <iostream>
+#include <iomanip>
+#include <string>
+#include <vector>
+#include <cassert>
+
+/**
+ * llama_batch/sbatch/ubatch Test Program
+ * Tests the basic principles and functionality of batch processing
+ * Focuses on split_simple operation and state modifications
+ *
+ * Data Flow Diagram:
+ * ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
+ * │   llama_batch   │───▶│  llama_sbatch   │───▶│  llama_ubatch   │
+ * │  (raw input)    │    │ (sorted/grouped)│    │  (view/subset)  │
+ * │                 │    │                 │    │                 │
+ * │ token[]: [A,B,C]│    │ seq[]: groups   │    │ token: ptr→data │
+ * │ pos[]:   [0,1,2]│    │ ids[]: [0,1,2]  │    │ n_tokens: count │
+ * │ seq_id: [0,0,0] │    │ offset: 0       │    │ equal_seqs: T/F │
+ * └─────────────────┘    │ length: 3       │    └─────────────────┘
+ *                        └─────────────────┘
+ */
+
+struct test_scope {
+    const char * name;
+    explicit test_scope(const char * name) : name(name) {
+        std::cout << "\n╔══════════════════════════════════════════════════════════════════════════════════════╗\n";
+        std::cout
<< "║ " << std::left << std::setw(84) << name << " ║\n"; + std::cout << "╚══════════════════════════════════════════════════════════════════════════════════════╝\n"; + } + ~test_scope() { + std::cout << "\n✅ " << name << " Test Completed Successfully\n"; + std::cout << "═══════════════════════════════════════════════════════════════════════════════════════\n\n"; + } +}; + +// Helper function to print batch details +static void print_batch_details(const llama_batch& batch, const std::string& title) { + std::cout << "\n" << title << " Details:\n"; + std::cout << "---------------------------------------------\n"; + std::cout << "Total Tokens: " << batch.n_tokens << "\n"; + + if (batch.token) { + std::cout << "Tokens: "; + for (int i = 0; i < batch.n_tokens; ++i) { + std::cout << batch.token[i] << " "; + } + std::cout << "\n"; + } + + if (batch.pos) { + std::cout << "Positions: "; + for (int i = 0; i < batch.n_tokens; ++i) { + std::cout << batch.pos[i] << " "; + } + std::cout << "\n"; + } + + if (batch.n_seq_id && batch.seq_id) { + std::cout << "Sequence Details:\n"; + for (int i = 0; i < batch.n_tokens; ++i) { + std::cout << " Token[" << i << "]: seq_ids=["; + for (int j = 0; j < batch.n_seq_id[i]; ++j) { + std::cout << batch.seq_id[i][j]; + if (j < batch.n_seq_id[i] - 1) std::cout << ","; + } + std::cout << "]\n"; + } + } + + if (batch.logits) { + std::cout << "Output Flags: "; + for (int i = 0; i < batch.n_tokens; ++i) { + std::cout << (int)batch.logits[i] << " "; + } + std::cout << "\n"; + } + std::cout << "---------------------------------------------\n"; +} + +// Helper function to print sbatch details +static void print_sbatch_details(const llama_sbatch& sbatch, const std::string& title) { + std::cout << "\n" << title << " Details:\n"; + std::cout << "---------------------------------------------\n"; + std::cout << "Total Tokens: " << sbatch.n_tokens << "\n"; + std::cout << "Sequences: " << sbatch.seq.size() << "\n"; + + for (size_t i = 0; i < sbatch.seq.size(); 
++i) { + const auto& s = sbatch.seq[i]; + std::cout << "Sequence[" << i << "]: " + << "offset=" << s.offset + << ", length=" << s.length << "\n"; + + if (s.seq_id && s.n_seq_id > 0) { + std::cout << " Sequence IDs: ["; + for (int j = 0; j < s.n_seq_id; ++j) { + std::cout << s.seq_id[j]; + if (j < s.n_seq_id - 1) std::cout << ","; + } + std::cout << "]\n"; + } + } + + std::cout << "Sorted Token Order: "; + for (size_t i = 0; i < sbatch.ids.size(); ++i) { + std::cout << sbatch.ids[i] << " "; + } + std::cout << "\n"; + std::cout << "---------------------------------------------\n"; +} + +// Helper function to print ubatch details +static void print_ubatch_details(const llama_ubatch& ubatch, const std::string& title) { + std::cout << "\n" << title << " Details:\n"; + std::cout << "---------------------------------------------\n"; + std::cout << "Equal Sequences: " << (ubatch.equal_seqs ? "true" : "false") << "\n"; + std::cout << "Total Tokens: " << ubatch.n_tokens << "\n"; + std::cout << "Tokens per Sequence: " << ubatch.n_seq_tokens << "\n"; + std::cout << "Number of Sequences: " << ubatch.n_seqs << "\n"; + + if (ubatch.token) { + std::cout << "Tokens: "; + for (size_t i = 0; i < ubatch.n_tokens; ++i) { + std::cout << ubatch.token[i] << " "; + } + std::cout << "\n"; + } + + if (ubatch.pos) { + std::cout << "Positions: "; + for (size_t i = 0; i < ubatch.n_tokens; ++i) { + std::cout << ubatch.pos[i] << " "; + } + std::cout << "\n"; + } + + if (ubatch.n_seq_id) { + std::cout << "Sequence ID Details: "; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < ubatch.n_seqs; ++i) { + std::cout << ubatch.n_seq_id[i] << " "; + } + } else { + for (size_t i = 0; i < ubatch.n_tokens; ++i) { + std::cout << ubatch.n_seq_id[i] << " "; + } + } + std::cout << "\n"; + } + + if (ubatch.output) { + std::cout << "Output Flags: "; + for (size_t i = 0; i < ubatch.n_tokens; ++i) { + std::cout << (int)ubatch.output[i] << " "; + } + std::cout << "\n"; + } + std::cout << 
"---------------------------------------------\n"; +} + +// Test 1: Basic Batch Creation and Conversion +static void test_basic_batch_conversion() { + test_scope scope("Basic Batch Creation and Conversion"); + + /* + * Basic Conversion Flow: + * + * llama_batch (raw input): + * ┌─────┬─────┬─────┬─────┬─────┐ + * │ 100 │ 101 │ 102 │ 103 │ 104 │ ← tokens + * │ 0 │ 1 │ 2 │ 3 │ 4 │ ← positions + * │ 0 │ 0 │ 0 │ 0 │ 0 │ ← seq_id + * └─────┴─────┴─────┴─────┴─────┘ + * ↓ + * llama_sbatch (simple_split=true): + * ┌─────────────────────────────────┐ + * │ seq[0]: {n_seq_id=0, offset=0, │ + * │ length=5} │ + * │ ids[]: [0,1,2,3,4] │ + * └─────────────────────────────────┘ + */ + + // Create a simple batch with 5 tokens in one sequence + llama_batch batch = llama_batch_init(10, 0, 2); // max 10 tokens, no embeddings, max 2 seqs + + // Add tokens to sequence 0 + llama_seq_id seq_0 = 0; + common_batch_add(batch, 100, 0, {seq_0}, false); // token 100 at pos 0 + common_batch_add(batch, 101, 1, {seq_0}, false); // token 101 at pos 1 + common_batch_add(batch, 102, 2, {seq_0}, false); // token 102 at pos 2 + common_batch_add(batch, 103, 3, {seq_0}, false); // token 103 at pos 3 + common_batch_add(batch, 104, 4, {seq_0}, true); // token 104 at pos 4, output=true + + print_batch_details(batch, "Original Batch"); + + // Convert to sbatch with simple split mode + llama_sbatch sbatch(batch, 64, true, false); // n_embd=64, simple_split=true, logits_all=false + + print_sbatch_details(sbatch, "Simple Split SBatch"); + + // Verify that simple split creates one sequence with n_seq_id = 0 + GGML_ASSERT(sbatch.seq.size() == 1); + GGML_ASSERT(sbatch.seq[0].n_seq_id == 0); + GGML_ASSERT(sbatch.seq[0].length == 5); + GGML_ASSERT(sbatch.seq[0].offset == 0); + + llama_batch_free(batch); +} + +// Test 2: Testing split_simple Operation and State Modification +static void test_split_simple_modification() { + test_scope scope("Split Simple Operation and State Modification"); + + /* + * split_simple 
State Modification Visualization: + * + * Initial sbatch state: + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ ← token data + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ▲ ▲ + * offset=0 offset+length=6 + * + * After split_simple(2): + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑consumed↑ ▲ ▲ + * offset=2 offset+length=6 + * + * After split_simple(3): + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─── consumed ────↑ ▲ ▲ + * offset=5 offset+length=6 + * + * Key insight: split_simple "consumes" tokens from the head by advancing offset! + */ + + // Create a batch with 6 tokens + llama_batch batch = llama_batch_init(10, 0, 1); + + llama_seq_id seq_0 = 0; + for (int i = 0; i < 6; ++i) { + // is_logits? + common_batch_add(batch, 200 + i, i, {seq_0}, i == 5); // last token outputs + } + + print_batch_details(batch, "Original Batch (6 tokens)"); + + // Convert to sbatch + llama_sbatch sbatch(batch, 64, true, false); + + print_sbatch_details(sbatch, "Initial SBatch State"); + + std::cout << "\n=== Testing Multiple split_simple Calls ===\n"; + + // First split_simple call - take 2 tokens + std::cout << "\n--- First split_simple(2) ---\n"; + std::cout << "Before split_simple:\n"; + std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + /* + * Visual representation of split_simple(2): + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─ extract these 2 ─↑ ↑─ remaining ─↑ + * → ubatch1 → sbatch.seq[0] + */ + + llama_ubatch ubatch1 = sbatch.split_simple(2); + + std::cout << "After split_simple:\n"; + std::cout << " seq[0].offset = " << 
sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + print_ubatch_details(ubatch1, "First UBatch (2 tokens)"); + + // Verify the modifications + GGML_ASSERT(sbatch.seq[0].offset == 2); // offset advanced by 2 + GGML_ASSERT(sbatch.seq[0].length == 4); // length reduced by 2 + GGML_ASSERT(sbatch.n_tokens == 4); // total tokens reduced by 2 + GGML_ASSERT(ubatch1.n_tokens == 2); // ubatch contains 2 tokens + + // Second split_simple call - take 3 tokens + std::cout << "\n--- Second split_simple(3) ---\n"; + std::cout << "Before split_simple:\n"; + std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + /* + * Visual representation of split_simple(3): + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─consumed─↑ ↑─extract these 3─↑↑─remaining─↑ + * → ubatch2 → sbatch.seq[0] + */ + + llama_ubatch ubatch2 = sbatch.split_simple(3); + + std::cout << "After split_simple:\n"; + std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + print_ubatch_details(ubatch2, "Second UBatch (3 tokens)"); + + // Verify the modifications + GGML_ASSERT(sbatch.seq[0].offset == 5); // offset advanced by 3 more + GGML_ASSERT(sbatch.seq[0].length == 1); // length reduced by 3 more + GGML_ASSERT(sbatch.n_tokens == 1); // total tokens reduced by 3 more + GGML_ASSERT(ubatch2.n_tokens == 3); // ubatch contains 3 tokens + + // Third split_simple call - take remaining token + std::cout << "\n--- Third split_simple(10) (should only get 1 token) ---\n"; + std::cout << "Before split_simple:\n"; + std::cout << " seq[0].offset 
= " << sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + /* + * Visual representation - requesting more than available: + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─────consumed──────────────↑ ↑only 1↑ + * remaining + */ + + llama_ubatch ubatch3 = sbatch.split_simple(10); // Request more than available + + std::cout << "After split_simple:\n"; + std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n"; + std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n"; + std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n"; + + print_ubatch_details(ubatch3, "Third UBatch (1 token)"); + + // Verify the modifications + GGML_ASSERT(sbatch.seq[0].offset == 6); // offset advanced by 1 more + GGML_ASSERT(sbatch.seq[0].length == 0); // length reduced to 0 + GGML_ASSERT(sbatch.n_tokens == 0); // no more tokens + GGML_ASSERT(ubatch3.n_tokens == 1); // ubatch contains 1 token + + // Fourth split_simple call - should return empty ubatch + std::cout << "\n--- Fourth split_simple(1) (should be empty) ---\n"; + + /* + * Visual representation - nothing left: + * ┌─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ + * └─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─────────all consumed────────────↑ + * offset=6, length=0 + */ + + llama_ubatch ubatch4 = sbatch.split_simple(1); + print_ubatch_details(ubatch4, "Fourth UBatch (empty)"); + + GGML_ASSERT(ubatch4.n_tokens == 0); // no tokens available + + std::cout << "\n✓ All state modifications verified correctly!\n"; + + llama_batch_free(batch); +} + +// Test 3: Multi-Sequence Batch Processing +static void test_multi_sequence_batch() { + test_scope scope("Multi-Sequence Batch Processing"); + + /* + * Multi-Sequence Processing Visualization: + * + * Original batch (mixed sequences): + * 
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │ 999 │ + * │seq:0│seq:0│seq:0│seq:1│seq:1│seq:2│0&1 │ + * │pos:0│pos:1│pos:2│pos:0│pos:1│pos:0│pos:10│ + * └─────┴─────┴─────┴─────┴─────┴─────┴─────┘ + * + * After sbatch sorting (complex mode): + * ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 999 │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │ + * │0&1 │seq:0│seq:0│seq:0│seq:1│seq:1│seq:2│ + * │pos:10│pos:0│pos:1│pos:2│pos:0│pos:1│pos:0│ + * └─────┴─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑ ↑─────seq 0──────↑ ↑─seq 1─↑ ↑seq2↑ + * shared (sorted by pos) + * prompt + * + * Simple split mode treats everything as one sequence: + * ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐ + * │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │ 999 │ + * │ │ │ │ │ │ │ │ + * └─────┴─────┴─────┴─────┴─────┴─────┴─────┘ + * ↑─────────all treated as seq_id=0──────────↑ + */ + + // Create a batch with multiple sequences + llama_batch batch = llama_batch_init(20, 0, 3); + + llama_seq_id seq_0 = 0; + llama_seq_id seq_1 = 1; + llama_seq_id seq_2 = 2; + + // Add tokens to different sequences + common_batch_add(batch, 300, 0, {seq_0}, false); // seq_0: pos 0 + common_batch_add(batch, 301, 1, {seq_0}, false); // seq_0: pos 1 + common_batch_add(batch, 302, 2, {seq_0}, true); // seq_0: pos 2, output + + common_batch_add(batch, 400, 0, {seq_1}, false); // seq_1: pos 0 + common_batch_add(batch, 401, 1, {seq_1}, true); // seq_1: pos 1, output + + common_batch_add(batch, 500, 0, {seq_2}, true); // seq_2: pos 0, output + + // Add a shared prompt token (belongs to multiple sequences) + common_batch_add(batch, 999, 10, {seq_0, seq_1}, false); // shared between seq_0 and seq_1 + + print_batch_details(batch, "Multi-Sequence Batch"); + + // Convert to sbatch with complex split mode (simple_split=false) + llama_sbatch sbatch_complex(batch, 64, false, false); + + print_sbatch_details(sbatch_complex, "Complex SBatch (sorted by seq_id)"); + + std::cout << "\n=== Testing split_equal and 
split_seq ===\n"; + + /* + * split_equal strategy: + * - Processes sequences by equal-length batches + * - Shared prompts processed first (highest priority) + * - Equal length sequences grouped together + * + * split_seq strategy: + * - Processes one sequence at a time + * - Takes from the end of sequence list + * - Good for sequential processing + */ + + // Test split_equal + llama_ubatch ubatch_equal = sbatch_complex.split_equal(10); + print_ubatch_details(ubatch_equal, "Split Equal Result"); + + // Test split_seq + llama_ubatch ubatch_seq = sbatch_complex.split_seq(5); + print_ubatch_details(ubatch_seq, "Split Seq Result"); + + // Compare with simple split approach + llama_sbatch sbatch_simple(batch, 64, true, false); + print_sbatch_details(sbatch_simple, "Simple SBatch"); + + llama_ubatch ubatch_simple = sbatch_simple.split_simple(10); + print_ubatch_details(ubatch_simple, "Simple Split Result"); + + llama_batch_free(batch); +} + +// Test 4: Edge Cases and Error Conditions +static void test_edge_cases() { + test_scope scope("Edge Cases and Error Conditions"); + + /* + * Edge Case Testing: + * + * Empty batch: + * ┌─┐ + * │ │ ← no tokens + * └─┘ + * + * Single token batch: + * ┌─────┐ + * │ 777 │ ← one token + * └─────┘ + * + * After split: + * ┌─┐ + * │ │ ← empty sbatch + * └─┘ + */ + + // Test empty batch + llama_batch empty_batch = llama_batch_init(5, 0, 1); + // Don't add any tokens + + print_batch_details(empty_batch, "Empty Batch"); + + llama_sbatch empty_sbatch(empty_batch, 64, true, false); + print_sbatch_details(empty_sbatch, "Empty SBatch"); + + llama_ubatch empty_ubatch = empty_sbatch.split_simple(5); + print_ubatch_details(empty_ubatch, "Empty UBatch from split_simple"); + + GGML_ASSERT(empty_ubatch.n_tokens == 0); + GGML_ASSERT(empty_sbatch.seq.empty()); + + // Test single token batch + llama_batch single_batch = llama_batch_init(5, 0, 1); + common_batch_add(single_batch, 777, 0, {0}, true); + + print_batch_details(single_batch, "Single Token 
Batch"); + + llama_sbatch single_sbatch(single_batch, 64, true, false); + print_sbatch_details(single_sbatch, "Single Token SBatch"); + + llama_ubatch single_ubatch = single_sbatch.split_simple(1); + print_ubatch_details(single_ubatch, "Single Token UBatch"); + + GGML_ASSERT(single_ubatch.n_tokens == 1); + GGML_ASSERT(single_ubatch.token[0] == 777); + + // After split, sbatch should be empty + llama_ubatch post_split_ubatch = single_sbatch.split_simple(1); + GGML_ASSERT(post_split_ubatch.n_tokens == 0); + + llama_batch_free(empty_batch); + llama_batch_free(single_batch); +} + +int main(int argc, char** argv) { + std::cout << "llama_batch/sbatch/ubatch Test Program\n"; + std::cout << "=====================================\n"; + std::cout << "Testing batch processing principles and split_simple modifications\n"; + + /* + * Overall Test Architecture: + * + * ┌─────────────────────────┐ + * │ Input Validation │ + * │ (test_basic_batch_*) │ + * └───────────┬─────────────┘ + * ▼ + * ┌─────────────────────────┐ + * │ Core Functionality │ + * │(test_split_simple_*) │ ← Main focus: state modification + * └───────────┬─────────────┘ + * ▼ + * ┌─────────────────────────┐ + * │ Complex Scenarios │ + * │(test_multi_sequence_*) │ + * └───────────┬─────────────┘ + * ▼ + * ┌─────────────────────────┐ + * │ Edge Cases & │ + * │ Data Integrity │ + * └─────────────────────────┘ + */ + + test_basic_batch_conversion(); + test_split_simple_modification(); + test_multi_sequence_batch(); + test_edge_cases(); + + return 0; +} \ No newline at end of file