 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

 struct GTEAttention {
@@ -72,7 +73,7 @@ impl GTEAttention {
         let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
         let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

         let attention = flash_attn_varlen(
             &q,
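As context for the call above: `apply_rotary_inplace` rotates the query and key head dimensions by the precomputed cos/sin values before attention. Below is a minimal standalone sketch of the usual NeoX-style half-rotation; `rotate_neox` is a hypothetical illustration, not the candle_rotary kernel, whose in-place implementation and flags differ.

    /// Hypothetical sketch: NeoX-style rotary embedding applied to one head vector.
    /// `x` has length `head_dim`; `cos`/`sin` hold `head_dim / 2` values taken from
    /// the cos/sin cache row for this token's position.
    fn rotate_neox(x: &mut [f32], cos: &[f32], sin: &[f32]) {
        let half = x.len() / 2;
        for i in 0..half {
            let a = x[i];
            let b = x[half + i];
            // Rotate the (a, b) pair by the angle position * inv_freq[i].
            x[i] = a * cos[i] - b * sin[i];
            x[half + i] = a * sin[i] + b * cos[i];
        }
    }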
@@ -93,60 +94,7 @@ impl GTEAttention {
     }
 }

-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
@@ -198,58 +146,6 @@ impl GTELayer {
     }
 }

-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
@@ -322,24 +218,19 @@ impl FlashGTEModel {
             config.layer_norm_eps,
         )?;

-        let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
-            let inv_freqs = candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta * factor,
-                vb.device(),
-            )?;
-            let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
-            inv_freqs / s
-        } else {
-            candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta,
-                vb.device(),
-            )
-        }?;
-
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(
+            layers[0].attention.attention_head_size,
+            config.rope_theta,
+            vb.device(),
+            config.rope_scaling.as_ref(),
+        )?;
+
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;

         Ok(Self {
             word_embeddings,