From eda2c4199dd810218bf0316fd154790db7e193a6 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 18 Apr 2025 15:31:23 +0900 Subject: [PATCH 1/5] fix: weight name of classification head --- backends/candle/src/models/modernbert.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs index c32cc5b3..6207211d 100644 --- a/backends/candle/src/models/modernbert.rs +++ b/backends/candle/src/models/modernbert.rs @@ -412,24 +412,27 @@ pub struct ModernBertClassificationHead { impl ModernBertClassificationHead { pub(crate) fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result { let dense_weight = vb - .pp("dense") + .pp("head.dense") .get((config.hidden_size, config.hidden_size), "weight")?; - let dense_bias = vb.pp("dense").get(config.hidden_size, "bias").ok(); + let dense_bias = vb.pp("head.dense").get(config.hidden_size, "bias").ok(); let dense = Linear::new( dense_weight, dense_bias, Some(config.classifier_activation.clone()), ); - let norm = - LayerNormNoBias::load(vb.pp("norm"), config.hidden_size, config.norm_eps as f32)?; + let norm = LayerNormNoBias::load( + vb.pp("head.norm"), + config.hidden_size, + config.norm_eps as f32, + )?; - let classifier_weight = vb.pp("dense").get( + let classifier_weight = vb.pp("classifier").get( (config.num_labels.unwrap_or(1), config.hidden_size), "weight", )?; let classifier_bias = vb - .pp("dense") + .pp("classifier") .get(config.num_labels.unwrap_or(1), "bias") .ok(); let classifier = Linear::new(classifier_weight, classifier_bias, None); From 9479cfa1d1afc9a0a1d1aa56b8c4658288aee014 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 18 Apr 2025 15:31:35 +0900 Subject: [PATCH 2/5] add: test case for modernbert classifier --- backends/candle/tests/test_modernbert.rs | 43 +++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/backends/candle/tests/test_modernbert.rs b/backends/candle/tests/test_modernbert.rs index 625658cd..fe75692e 100644 --- a/backends/candle/tests/test_modernbert.rs +++ b/backends/candle/tests/test_modernbert.rs @@ -2,7 +2,9 @@ mod common; use crate::common::{sort_embeddings, SnapshotEmbeddings}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{ + batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher, SnapshotScores, +}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -135,3 +137,42 @@ fn test_mini_pooled_raw() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +fn test_modernbert_classification() -> Result<()> { + let model_root = download_artifacts("Alibaba-NLP/gte-reranker-modernbert-base", None).unwrap(); + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; + + let input_single = batch( + vec![tokenizer + .encode( + ( + "PrimeTime is a timing signoff tool", + "PrimeTime can perform most accurate timing analysis", + ), + true, + ) + .unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec> = backend + .predict(input_single)? + .into_iter() + .map(|(_, v)| v) + .collect(); + let predictions_single = SnapshotScores::from(predictions); + + let matcher = relative_matcher(); + insta::assert_yaml_snapshot!( + "modernbert_classification_single", + predictions_single, + &matcher + ); + + Ok(()) +} From ae75ed53a295f51b31aae96a90b7361e93d87733 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 18 Apr 2025 15:46:27 +0900 Subject: [PATCH 3/5] update: test --- ...test_modernbert__modernbert_classification_single.snap | 5 +++++ backends/candle/tests/test_modernbert.rs | 8 +------- 2 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap diff --git a/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap new file mode 100644 index 00000000..20ef4d37 --- /dev/null +++ b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap @@ -0,0 +1,5 @@ +--- +source: backends/candle/tests/test_modernbert.rs +expression: predictions_single +--- +- - 2.2585099 diff --git a/backends/candle/tests/test_modernbert.rs b/backends/candle/tests/test_modernbert.rs index fe75692e..2ed288c0 100644 --- a/backends/candle/tests/test_modernbert.rs +++ b/backends/candle/tests/test_modernbert.rs @@ -148,13 +148,7 @@ fn test_modernbert_classification() -> Result<()> { let input_single = batch( vec![tokenizer - .encode( - ( - "PrimeTime is a timing signoff tool", - "PrimeTime can perform most accurate timing analysis", - ), - true, - ) + .encode(("What is Deep Learning?", "Deep Learning is not..."), true) .unwrap()], [0].to_vec(), vec![], From 7334fc064d3a02b2a2b3b9eae83a97748267daac Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 18 Apr 2025 15:55:13 +0900 Subject: [PATCH 4/5] docs: reranker models --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 520ad8d0..4bf90763 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Below are some examples of the currently supported models: | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | | Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) | +| Re-Ranking | ModernBert | [Alibaba-NLP/gte-reranker-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base) | | Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | ### Docker From 67b7cff0fe3dc3df77c2bd03badb1fbc85ee8389 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 18 Apr 2025 17:09:19 +0900 Subject: [PATCH 5/5] update: ModernBertClassificationHead --- backends/candle/src/models/modernbert.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs index 6207211d..12a9072a 100644 --- a/backends/candle/src/models/modernbert.rs +++ b/backends/candle/src/models/modernbert.rs @@ -414,10 +414,9 @@ impl ModernBertClassificationHead { let dense_weight = vb .pp("head.dense") .get((config.hidden_size, config.hidden_size), "weight")?; - let dense_bias = vb.pp("head.dense").get(config.hidden_size, "bias").ok(); let dense = Linear::new( dense_weight, - dense_bias, + None, Some(config.classifier_activation.clone()), ); @@ -433,9 +432,8 @@ impl ModernBertClassificationHead { )?; let classifier_bias = vb .pp("classifier") - .get(config.num_labels.unwrap_or(1), "bias") - .ok(); - let classifier = Linear::new(classifier_weight, classifier_bias, None); + .get(config.num_labels.unwrap_or(1), "bias")?; + let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); Ok(Self { dense,