@@ -117,29 +117,40 @@ def prepare(self):
117
117
# hidden state space before the target model forward.
118
118
start_idx = 0
119
119
if not self .is_draft_model :
120
- for req_id , seq_len in zip (self .request_ids , self .seq_lens ):
121
- slot_id = self .eagle3_resource_manager .slot_manager .get_slot (
122
- req_id )
123
- self .eagle3_resource_manager .start_indices [slot_id ] = start_idx
124
- start_idx += seq_len
120
+ if self .request_ids is not None and self .seq_lens is not None :
121
+ for req_id , seq_len in zip (self .request_ids , self .seq_lens ):
122
+ slot_id = self .eagle3_resource_manager .slot_manager .get_slot (
123
+ req_id
124
+ ) if self .eagle3_resource_manager is not None else None
125
+ if self .eagle3_resource_manager is not None and slot_id is not None :
126
+ self .eagle3_resource_manager .start_indices [
127
+ slot_id ] = start_idx
128
+ start_idx += seq_len
125
129
# Prepare hidden states gather ids
126
130
hidden_states_read_indices = []
127
131
hidden_states_write_indices = []
128
- for req_id , seq_len in zip (self .request_ids , self .seq_lens ):
129
- slot_id = self .eagle3_resource_manager .slot_manager .get_slot (req_id )
130
- start_idx = self .eagle3_resource_manager .start_indices [slot_id ]
131
- # If this is the first draft or the target model forward, we need to
132
- # read/write all of the hidden states, otherwise, only read the last token
133
- if is_first_draft or not self .is_draft_model :
134
- hidden_states_read_indices .extend (
135
- list (range (start_idx , start_idx + seq_len )))
136
- hidden_states_write_indices .extend (
137
- list (range (start_idx , start_idx + seq_len )))
138
- else :
139
- old_seq_len = self .eagle3_resource_manager .seq_lens [slot_id ]
140
- hidden_states_read_indices .append (start_idx + old_seq_len - 1 )
141
- hidden_states_write_indices .append (start_idx + seq_len - 1 )
142
- self .eagle3_resource_manager .seq_lens [slot_id ] = seq_len
132
+ if self .request_ids is not None and self .seq_lens is not None :
133
+ for req_id , seq_len in zip (self .request_ids , self .seq_lens ):
134
+ if self .eagle3_resource_manager is not None :
135
+ slot_id = self .eagle3_resource_manager .slot_manager .get_slot (
136
+ req_id )
137
+ start_idx = self .eagle3_resource_manager .start_indices [
138
+ slot_id ]
139
+ # If this is the first draft or the target model forward, we need to
140
+ # read/write all of the hidden states, otherwise, only read the last token
141
+ if is_first_draft or not self .is_draft_model :
142
+ hidden_states_read_indices .extend (
143
+ list (range (start_idx , start_idx + seq_len )))
144
+ hidden_states_write_indices .extend (
145
+ list (range (start_idx , start_idx + seq_len )))
146
+ else :
147
+ old_seq_len = self .eagle3_resource_manager .seq_lens [
148
+ slot_id ]
149
+ hidden_states_read_indices .append (start_idx +
150
+ old_seq_len - 1 )
151
+ hidden_states_write_indices .append (start_idx + seq_len -
152
+ 1 )
153
+ self .eagle3_resource_manager .seq_lens [slot_id ] = seq_len
143
154
# Prepare hidden states gather ids
144
155
self .hidden_states_read_indices_host = torch .tensor (
145
156
hidden_states_read_indices , dtype = torch .long , pin_memory = True )
0 commit comments