@@ -63,27 +63,24 @@ __device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const nu
 }
 
 template <typename T>
-__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const curMTPLayerIdx, int const batchSize,
-    int const numContextRequest, int const hiddenSize, int const* inputIds, int const* seqLens,
-    T** const mtpPastHiddenStatesPtrs, int** const mtpPastTokensPtrs, T* const previousLayerHiddenStates,
-    int* const previousLayerDraftTokens, int* returnInputIds, T* returnHiddenStates)
+__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const numContextRequest,
+    int const hiddenSize, int const* inputIds, int const* seqLens, T** const mtpPastHiddenStatesPtrs,
+    int** const mtpPastTokensPtrs, T* const hiddenStates, int const* acceptedTokens, int const* numAcceptedTokens,
+    int* returnInputIds, T* returnHiddenStates)
 {
     /*
         In a batch of requests: context requests (at the beginning) + generation requests
         numGenerationRequest = batchSize - numContextRequest
 
         inputIds: [N]
-            - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
-              + 1)
-            - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
+            - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1)
         seqLens: [batchSize]
         mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
         mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
-        previousLayerHiddenStates: [N, hiddenSize]
-            - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
-              + 1) (from target model)
-            - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
-        previousLayerDraftTokens: [batchSize], the draft tokens generated by the previous layer
+        hiddenStates: [N, hiddenSize]
+            - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) (from target model)
+        acceptedTokens: [batchSize, numMTPModules + 1]
+        numAcceptedTokens: [batchSize]
         returnInputIds: [N]
             - N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
         returnHiddenStates: [N, hiddenSize]
@@ -94,6 +91,7 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
 
     T const* curMTPPastHiddenStatesPtr = mtpPastHiddenStatesPtrs[bid];
     int const* curMTPPastTokensPtr = mtpPastTokensPtrs[bid];
+    int const* curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);
 
     int curSeqLen = seqLens[bid];
 
@@ -117,63 +115,44 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
     }
 
     int const* curInputIdsPtr = inputIds + inputIdsStartOffset;
-    T const* curPreviousLayerHiddenStates = previousLayerHiddenStates + inputIdsStartOffset * hiddenSize;
+    T const* curHiddenStates = hiddenStates + inputIdsStartOffset * hiddenSize;
 
     int* curReturnInputIdsPtr = returnInputIds + returnInputIdsStartOffset;
     T* curReturnHiddenStatesIdsPtr = returnHiddenStates + returnInputIdsStartOffset * hiddenSize;
 
     // // main logic
-
-    if (curMTPLayerIdx == 0)
+    if (bid < numContextRequest)
     {
-        if (bid < numContextRequest)
+        // context requests
+        if (tid == 0)
         {
-            // context requests
-            if (tid == 0)
+            // 1) For the new inputIds
+            for (int ii = 0; ii < curSeqLen - 1; ii++)
             {
-                // 1) For the new inputIds
-                for (int ii = 0; ii < curSeqLen - 1; ii++)
-                {
-                    curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
-                }
-                // Append the latest golden token, i.e., the last one in the past tokens list
-                curReturnInputIdsPtr[curSeqLen - 1] = curMTPPastTokensPtr[numMTPModules - 1];
+                curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
             }
-
-            // 2) For the new past hidden states
-            copyChunkedHiddenStates(curPreviousLayerHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
+            // Append the latest golden token, i.e., the first one in the accepted tokens list
+            curReturnInputIdsPtr[curSeqLen - 1] = curAcceptedTokensPtr[0];
         }
-        else
-        {
-            // generation requests
-            if (tid == 0)
-            {
-                // 1) For the new inputIds
-                for (int ii = 0; ii < numMTPModules; ii++)
-                {
-                    curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii];
-                }
-            }
 
-            // 2) For the new past hidden states
-            copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
-        }
+        // 2) For the new past hidden states
+        copyChunkedHiddenStates(curHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
     }
-    else // For curMTPLayerIdx > 0
+    else
     {
+        // generation requests
         if (tid == 0)
         {
             // 1) For the new inputIds
-            int numPastTokens = (bid < numContextRequest) ? curSeqLen : numMTPModules;
-            for (int ii = 0; ii < numPastTokens; ii++)
+            for (int ii = 0; ii < numMTPModules - 1; ii++)
             {
-                curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1];
+                curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii + 1];
             }
-            curReturnInputIdsPtr[numPastTokens - 1] = previousLayerDraftTokens[bid];
+            curReturnInputIdsPtr[numMTPModules - 1] = curAcceptedTokensPtr[numAcceptedTokens[bid] - 1];
        }
 
         // 2) For the new past hidden states
-        // Directly use previous layer's output hidden states
+        copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
     }
 }
 
@@ -185,10 +164,10 @@ void invokeMTPPrepareDrafterInputs(MTPPrepareDrafterInputsParam& params, cudaStr
         params.hiddenSize * sizeof(T) % 16 == 0); // This is because we will use float4 to copy the hidden states.
 
     mtpPrepareDrafterInputsKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules,
-        params.curMTPLayerIdx, params.batchSize, params.numContextRequest, params.hiddenSize, params.inputIds,
-        params.seqLens, reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
-        reinterpret_cast<T*>(params.previousLayerHiddenStates), params.previousLayerDraftTokens, params.returnInputIds,
-        reinterpret_cast<T*>(params.returnHiddenStates));
+        params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
+        reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
+        reinterpret_cast<T*>(params.hiddenStates), params.acceptedTokens, params.numAcceptedTokens,
+        params.returnInputIds, reinterpret_cast<T*>(params.returnHiddenStates));
 
     sync_check_cuda_error(stream);
 }
@@ -362,7 +341,7 @@ template void invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(
 template <typename T>
 __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const batchSize, int const numContextRequest,
     int const hiddenSize, int const* inputIds, int const* seqLens, T* targetModelHiddenStates,
-    T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens, int const* acceptedTokens)
+    T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens)
 {
     /*
         In a batch of requests: context requests (at the beginning) + generation requests
@@ -374,7 +353,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
         mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
         mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
         numAcceptedTokens: [batchSize]
-        acceptedTokens: [batchSize][numMTPModules + 1], flatten
     */
 
     int const bid = static_cast<int>(blockIdx.x); // Each block is responsible for a request.
@@ -395,7 +373,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
 
     auto curInputIdsPtr = inputIds + inputIdsStartOffset;
     auto curTargetModelHiddenStatesPtr = targetModelHiddenStates + inputIdsStartOffset * hiddenSize;
-    auto curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);
 
     // Update MTP tokens
     // Just use one thread to execute this copy
@@ -405,12 +382,10 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
         {
             // Context request
             // Copy the end of prompt tokens
-            for (int ii = 0; ii < numMTPModules - 1; ii++)
+            for (int ii = 0; ii < numMTPModules; ii++)
             {
-                curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + 1 + ii];
+                curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + ii];
             }
-            // Copy the new generated golden token
-            curMTPPastTokensPtr[numMTPModules - 1] = curAcceptedTokensPtr[0];
         }
         else
         {
@@ -424,7 +399,7 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
             int acceptedTokenStartIdx = max(0, curAcceptedLen - numMTPModules);
             for (; ii < numMTPModules; ii++, acceptedTokenStartIdx++)
             {
-                curMTPPastTokensPtr[ii] = curAcceptedTokensPtr[acceptedTokenStartIdx];
+                curMTPPastTokensPtr[ii] = curInputIdsPtr[acceptedTokenStartIdx];
             }
         }
     }
@@ -463,7 +438,7 @@ void invokeMTPUpdateHiddenStates(MTPUpdateHiddenStatesParam& params, cudaStream_
     mtpUpdateHiddenStatesKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules, params.batchSize,
         params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
         reinterpret_cast<T*>(params.targetModelHiddenStates), reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs),
-        params.mtpPastTokensPtrs, params.numAcceptedTokens, params.acceptedTokens);
+        params.mtpPastTokensPtrs, params.numAcceptedTokens);
 
     sync_check_cuda_error(stream);
 }
 
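Note on the hidden-state copies: both kernels call copyChunkedHiddenStates, whose body lies outside these hunks; only its (truncated) signature appears in the first hunk header, and the launcher asserts params.hiddenSize * sizeof(T) % 16 == 0 so that the copy can use float4 transactions. The following is a hypothetical sketch of such a helper, not the committed implementation; the parameter name numElements and the per-block thread-strided loop are assumptions.

// Hypothetical sketch (not from this commit): copy numElements elements of type T
// from srcPtr to dstPtr using 16-byte float4 loads/stores, one block per request.
template <typename T>
__device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const numElements)
{
    // Safe to reinterpret as float4 because the launcher checks
    // hiddenSize * sizeof(T) % 16 == 0 and every call passes a multiple of hiddenSize.
    auto const numFloat4 = static_cast<int>(numElements * sizeof(T) / sizeof(float4));
    float4 const* src4 = reinterpret_cast<float4 const*>(srcPtr);
    float4* dst4 = reinterpret_cast<float4*>(dstPtr);
    // Thread-strided loop: all threads of the block cooperate on one request's chunk.
    for (int ii = static_cast<int>(threadIdx.x); ii < numFloat4; ii += blockDim.x)
    {
        dst4[ii] = src4[ii];
    }
}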