@@ -306,14 +306,7 @@ impl FlashBertModel {
306
306
let pool = Pool :: Cls ;
307
307
308
308
let classifier: Box < dyn ClassificationHead + Send > =
309
- if config. model_type == Some ( "bert" . to_string ( ) ) {
310
- Box :: new ( BertClassificationHead :: load ( vb. pp ( "classifier" ) , config) ?)
311
- } else {
312
- Box :: new ( RobertaClassificationHead :: load (
313
- vb. pp ( "classifier" ) ,
314
- config,
315
- ) ?)
316
- } ;
309
+ Box :: new ( BertClassificationHead :: load ( vb. pp ( "classifier" ) , config) ?) ;
317
310
( pool, Some ( classifier) )
318
311
}
319
312
ModelType :: Embedding ( pool) => ( pool, None ) ,
@@ -325,16 +318,78 @@ impl FlashBertModel {
325
318
) {
326
319
( Ok ( embeddings) , Ok ( encoder) ) => ( embeddings, encoder) ,
327
320
( Err ( err) , _) | ( _, Err ( err) ) => {
328
- let model_type = config. model_type . clone ( ) . unwrap_or ( "bert" . to_string ( ) ) ;
321
+ if let ( Ok ( embeddings) , Ok ( encoder) ) = (
322
+ BertEmbeddings :: load ( vb. pp ( "bert.embeddings" . to_string ( ) ) , config) ,
323
+ BertEncoder :: load ( vb. pp ( "bert.encoder" . to_string ( ) ) , config) ,
324
+ ) {
325
+ ( embeddings, encoder)
326
+ } else {
327
+ return Err ( err) ;
328
+ }
329
+ }
330
+ } ;
331
+
332
+ Ok ( Self {
333
+ embeddings,
334
+ encoder,
335
+ pool,
336
+ classifier,
337
+ device : vb. device ( ) . clone ( ) ,
338
+ span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
339
+ } )
340
+ }
341
+
342
+ pub fn load_roberta (
343
+ vb : VarBuilder ,
344
+ config : & BertConfig ,
345
+ model_type : ModelType ,
346
+ ) -> Result < Self > {
347
+ match vb. device ( ) {
348
+ Device :: Cuda ( _) => { }
349
+ _ => candle:: bail!( "FlashBert requires Cuda" ) ,
350
+ }
351
+
352
+ if vb. dtype ( ) != DType :: F16 {
353
+ candle:: bail!( "FlashBert requires DType::F16" )
354
+ }
355
+
356
+ // Check position embedding type
357
+ if config. position_embedding_type != PositionEmbeddingType :: Absolute {
358
+ candle:: bail!( "FlashBert only supports absolute position embeddings" )
359
+ }
360
+
361
+ let ( pool, classifier) = match model_type {
362
+ // Classifier models always use CLS pooling
363
+ ModelType :: Classifier => {
364
+ let pool = Pool :: Cls ;
365
+
366
+ let classifier: Box < dyn ClassificationHead + Send > = Box :: new (
367
+ RobertaClassificationHead :: load ( vb. pp ( "classifier" ) , config) ?,
368
+ ) ;
369
+ ( pool, Some ( classifier) )
370
+ }
371
+ ModelType :: Embedding ( pool) => ( pool, None ) ,
372
+ } ;
329
373
374
+ let ( embeddings, encoder) = match (
375
+ BertEmbeddings :: load ( vb. pp ( "embeddings" ) , config) ,
376
+ BertEncoder :: load ( vb. pp ( "encoder" ) , config) ,
377
+ ) {
378
+ ( Ok ( embeddings) , Ok ( encoder) ) => ( embeddings, encoder) ,
379
+ ( Err ( err) , _) | ( _, Err ( err) ) => {
330
380
if let ( Ok ( embeddings) , Ok ( encoder) ) = (
331
- BertEmbeddings :: load ( vb. pp ( format ! ( "{model_type}.embeddings" ) ) , config) ,
332
- BertEncoder :: load ( vb. pp ( format ! ( "{model_type}.encoder" ) ) , config) ,
381
+ BertEmbeddings :: load ( vb. pp ( "roberta.embeddings" . to_string ( ) ) , config) ,
382
+ BertEncoder :: load ( vb. pp ( "roberta.encoder" . to_string ( ) ) , config) ,
383
+ ) {
384
+ ( embeddings, encoder)
385
+ } else if let ( Ok ( embeddings) , Ok ( encoder) ) = (
386
+ BertEmbeddings :: load ( vb. pp ( "xlm-roberta.embeddings" . to_string ( ) ) , config) ,
387
+ BertEncoder :: load ( vb. pp ( "xlm-roberta.encoder" . to_string ( ) ) , config) ,
333
388
) {
334
389
( embeddings, encoder)
335
390
} else if let ( Ok ( embeddings) , Ok ( encoder) ) = (
336
- BertEmbeddings :: load ( vb. pp ( "roberta .embeddings" ) , config) ,
337
- BertEncoder :: load ( vb. pp ( "roberta .encoder" ) , config) ,
391
+ BertEmbeddings :: load ( vb. pp ( "camembert .embeddings" . to_string ( ) ) , config) ,
392
+ BertEncoder :: load ( vb. pp ( "camembert .encoder" . to_string ( ) ) , config) ,
338
393
) {
339
394
( embeddings, encoder)
340
395
} else {
0 commit comments