@@ -2450,12 +2450,48 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi
2450
2450
};
2451
2451
2452
2452
std::map<SizeType32, float > windowSizeToShare;
2453
- // NOTE: Righteously, blocks allocated should be proportional with
2454
- // regard to window size. Currently, we are first allocating identical
2455
- // number of blocks for all layers to achieve identical performance.
2456
- for (auto const & [windowSize, _] : windowSizeToLayers)
2453
+ if (auto envStr = std::getenv (" TRTLLM_WINDOW_SIZE_SHARES" ))
2457
2454
{
2458
- windowSizeToShare[windowSize] = 1 .0f / windowSizeToLayers.size ();
2455
+ std::stringstream ss (envStr);
2456
+ std::vector<float > shares;
2457
+ float share;
2458
+ while (ss >> share)
2459
+ {
2460
+ shares.push_back (share);
2461
+ if (ss.peek () == ' ,' )
2462
+ ss.ignore ();
2463
+ }
2464
+
2465
+ TLLM_CHECK_WITH_INFO (shares.size () == windowSizeToLayers.size (),
2466
+ " Number of shares in TRTLLM_WINDOW_SIZE_SHARES (%ld) must match number of window sizes (%ld)" ,
2467
+ shares.size (), windowSizeToLayers.size ());
2468
+ float sumShares = 0 .0f ;
2469
+ for (auto s : shares)
2470
+ {
2471
+ TLLM_CHECK_WITH_INFO (0 .0f <= s && s <= 1 .0f , " Shares must be in value range [0,1], got %f" , s);
2472
+ sumShares += s;
2473
+ }
2474
+ TLLM_CHECK_WITH_INFO (sumShares > 0 .0f , " Sum of shares must be > 0." );
2475
+ // Normalize shares to 1.0
2476
+ for (auto & s : shares)
2477
+ {
2478
+ s /= sumShares;
2479
+ }
2480
+ size_t i = 0 ;
2481
+ for (auto const & [windowSize, _] : windowSizeToLayers)
2482
+ {
2483
+ windowSizeToShare[windowSize] = shares[i++];
2484
+ }
2485
+ }
2486
+ else
2487
+ {
2488
+ // NOTE: Ideally, blocks allocated should be proportional with
2489
+ // regard to window size. Currently, we are first allocating identical
2490
+ // number of blocks for all layers to achieve identical performance.
2491
+ for (auto const & [windowSize, _] : windowSizeToLayers)
2492
+ {
2493
+ windowSizeToShare[windowSize] = 1 .0f / windowSizeToLayers.size ();
2494
+ }
2459
2495
}
2460
2496
2461
2497
std::vector<SizeType32> blocksPrimary;
0 commit comments