From a2ca54dd09bca6a5bbad9bb0e2c052d469fd9fcb Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 3 May 2025 23:47:31 +0900 Subject: [PATCH 1/6] fix: GTEClassificationHead --- backends/candle/src/models/gte.rs | 40 ++++++++++++------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index 55a225a4..b2b377cf 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -193,7 +193,6 @@ impl GTEMLP { let up_gate_proj_weight = vb .pp("up_gate_proj") .get((intermediate_size * 2, config.hidden_size), "weight")?; - let up_gate_proj = Linear::new(up_gate_proj_weight, None, None); let down_proj_weight = vb @@ -216,16 +215,12 @@ impl GTEMLP { let up_gate_states = self.up_gate_proj.forward(hidden_states)?; let up_states = up_gate_states.narrow(D::Minus1, 0, self.intermediate_size)?; - let gate_states = - up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; - let gate_states = match self.act { - HiddenAct::Gelu => gate_states.gelu(), - HiddenAct::Relu => gate_states.relu(), - HiddenAct::Swiglu => gate_states.silu(), - }?; + let gate = + up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; + let gate = self.act.forward(&gate)?; - self.down_proj.forward(&(gate_states * up_states)?) + self.down_proj.forward(&(gate * up_states)?) } } @@ -282,28 +277,23 @@ impl GTELayer { } pub struct GTEClassificationHead { - pooler: Option, + pooler: Linear, classifier: Linear, span: tracing::Span, } impl GTEClassificationHead { - #[allow(dead_code)] pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { let n_classes = match &config.id2label { None => candle::bail!("`id2label` must be set for classifier models"), Some(id2label) => id2label.len(), }; - let pooler = if let Ok(pooler_weight) = vb - .pp("pooler.dense") - .get((config.hidden_size, config.hidden_size), "weight") - { - let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; - Some(Linear::new(pooler_weight, Some(pooler_bias), None)) - } else { - None - }; + let pooler_weight = vb + .pp("new.pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight")?; + let pooler_bias = vb.pp("new.pooler.dense").get(config.hidden_size, "bias")?; + let pooler = Linear::new(pooler_weight, Some(pooler_bias), None); let classifier_weight = vb .pp("classifier") @@ -321,14 +311,14 @@ impl GTEClassificationHead { pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { let _enter = self.span.enter(); - let mut hidden_states = hidden_states.unsqueeze(1)?; - if let Some(pooler) = self.pooler.as_ref() { - hidden_states = pooler.forward(&hidden_states)?; - hidden_states = hidden_states.tanh()?; - } + let hidden_states = hidden_states.unsqueeze(1)?; + + let hidden_states = self.pooler.forward(&hidden_states)?; + let hidden_states = hidden_states.tanh()?; let hidden_states = self.classifier.forward(&hidden_states)?; let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) } } From 273078f8113159c52d565569b90f0e22879e7133 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 3 May 2025 23:47:58 +0900 Subject: [PATCH 2/6] add: test_gte_classification --- backends/candle/tests/test_gte.rs | 33 +++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/backends/candle/tests/test_gte.rs b/backends/candle/tests/test_gte.rs index b8fe3892..6a9e3e3b 100644 --- a/backends/candle/tests/test_gte.rs +++ b/backends/candle/tests/test_gte.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -137,3 +137,32 @@ fn test_snowflake_gte() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +fn test_gte_classification() -> Result<()> { + let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; + + let input_single = batch( + vec![tokenizer + .encode(("What is Deep Learning?", "Deep Learning is not..."), true) + .unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec> = backend + .predict(input_single)? + .into_iter() + .map(|(_, v)| v) + .collect(); + let predictions_single = SnapshotScores::from(predictions); + + let matcher = relative_matcher(); + insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher); + + Ok(()) +} From 10800969dd8f96dfabfb37262e3033fad0b36382 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 3 May 2025 23:52:39 +0900 Subject: [PATCH 3/6] add: gte_classification_single --- .../tests/snapshots/test_gte__gte_classification_single.snap | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 backends/candle/tests/snapshots/test_gte__gte_classification_single.snap diff --git a/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap b/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap new file mode 100644 index 00000000..b46f4646 --- /dev/null +++ b/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap @@ -0,0 +1,5 @@ +--- +source: backends/candle/tests/test_gte.rs +expression: predictions_single +--- +- - -0.74173266 From 29106a36421a97aa24fb93a92e8cb063d9d96a35 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 3 May 2025 16:10:42 +0000 Subject: [PATCH 4/6] update: test_flash_gte_classification --- .../snapshots/test_flash_gte__gte_classification_single.snap | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap index b45a2643..1acb12d4 100644 --- a/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap +++ b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap @@ -1,6 +1,5 @@ --- source: backends/candle/tests/test_flash_gte.rs -assertion_line: 83 expression: predictions_single --- -- - 0.050048828 +- - -0.7426758 From d6791a6896c518a33edacfd1bc8d773191e09469 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 4 May 2025 12:05:31 +0900 Subject: [PATCH 5/6] update: support new. prefix too --- backends/candle/src/models/gte.rs | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index b2b377cf..fc2714c4 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -277,23 +277,34 @@ impl GTELayer { } pub struct GTEClassificationHead { - pooler: Linear, + pooler: Option, classifier: Linear, span: tracing::Span, } impl GTEClassificationHead { + fn inner_load(vb: VarBuilder, config: >EConfig) -> Result> { + let pooler = if let Ok(pooler_weight) = vb + .pp("pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + { + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; + Some(Linear::new(pooler_weight, Some(pooler_bias), None)) + } else { + None + }; + + Ok(pooler) + } + pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { let n_classes = match &config.id2label { None => candle::bail!("`id2label` must be set for classifier models"), Some(id2label) => id2label.len(), }; - let pooler_weight = vb - .pp("new.pooler.dense") - .get((config.hidden_size, config.hidden_size), "weight")?; - let pooler_bias = vb.pp("new.pooler.dense").get(config.hidden_size, "bias")?; - let pooler = Linear::new(pooler_weight, Some(pooler_bias), None); + let pooler = Self::inner_load(vb.pp("new"), config) + .or_else(|_| Self::inner_load(vb.clone(), config))?; let classifier_weight = vb .pp("classifier") @@ -311,10 +322,12 @@ impl GTEClassificationHead { pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { let _enter = self.span.enter(); - let hidden_states = hidden_states.unsqueeze(1)?; + let mut hidden_states = hidden_states.unsqueeze(1)?; - let hidden_states = self.pooler.forward(&hidden_states)?; - let hidden_states = hidden_states.tanh()?; + if let Some(pooler) = self.pooler.as_ref() { + hidden_states = pooler.forward(&hidden_states)?; + hidden_states = hidden_states.tanh()?; + } let hidden_states = self.classifier.forward(&hidden_states)?; let hidden_states = hidden_states.squeeze(1)?; From c44648be6b274abc0acfb3eb64f146c80015e392 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 4 May 2025 14:25:24 +0900 Subject: [PATCH 6/6] update: support new. prefix too --- backends/candle/src/models/gte.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index fc2714c4..d5cf3412 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -283,18 +283,15 @@ pub struct GTEClassificationHead { } impl GTEClassificationHead { - fn inner_load(vb: VarBuilder, config: >EConfig) -> Result> { - let pooler = if let Ok(pooler_weight) = vb + fn inner_load(vb: VarBuilder, config: >EConfig) -> Option { + let pooler_weight = vb .pp("pooler.dense") .get((config.hidden_size, config.hidden_size), "weight") - { - let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; - Some(Linear::new(pooler_weight, Some(pooler_bias), None)) - } else { - None - }; + .ok()?; + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias").ok()?; + let pooler = Linear::new(pooler_weight, Some(pooler_bias), None); - Ok(pooler) + Some(pooler) } pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { @@ -303,8 +300,8 @@ impl GTEClassificationHead { Some(id2label) => id2label.len(), }; - let pooler = Self::inner_load(vb.pp("new"), config) - .or_else(|_| Self::inner_load(vb.clone(), config))?; + let pooler = + Self::inner_load(vb.pp("new"), config).or_else(|| Self::inner_load(vb.clone(), config)); let classifier_weight = vb .pp("classifier")