diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index 55a225a4..d5cf3412 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -193,7 +193,6 @@ impl GTEMLP { let up_gate_proj_weight = vb .pp("up_gate_proj") .get((intermediate_size * 2, config.hidden_size), "weight")?; - let up_gate_proj = Linear::new(up_gate_proj_weight, None, None); let down_proj_weight = vb @@ -216,16 +215,12 @@ impl GTEMLP { let up_gate_states = self.up_gate_proj.forward(hidden_states)?; let up_states = up_gate_states.narrow(D::Minus1, 0, self.intermediate_size)?; - let gate_states = - up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; - let gate_states = match self.act { - HiddenAct::Gelu => gate_states.gelu(), - HiddenAct::Relu => gate_states.relu(), - HiddenAct::Swiglu => gate_states.silu(), - }?; + let gate = + up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; + let gate = self.act.forward(&gate)?; - self.down_proj.forward(&(gate_states * up_states)?) + self.down_proj.forward(&(gate * up_states)?) 
} } @@ -288,22 +283,25 @@ pub struct GTEClassificationHead { } impl GTEClassificationHead { - #[allow(dead_code)] + fn inner_load(vb: VarBuilder, config: &GTEConfig) -> Option<Linear> { + let pooler_weight = vb + .pp("pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + .ok()?; + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias").ok()?; + let pooler = Linear::new(pooler_weight, Some(pooler_bias), None); + + Some(pooler) + } + pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> { let n_classes = match &config.id2label { None => candle::bail!("`id2label` must be set for classifier models"), Some(id2label) => id2label.len(), }; - let pooler = if let Ok(pooler_weight) = vb - .pp("pooler.dense") - .get((config.hidden_size, config.hidden_size), "weight") - { - let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; - Some(Linear::new(pooler_weight, Some(pooler_bias), None)) - } else { - None - }; + let pooler = + Self::inner_load(vb.pp("new"), config).or_else(|| Self::inner_load(vb.clone(), config)); let classifier_weight = vb .pp("classifier") @@ -322,6 +320,7 @@ impl GTEClassificationHead { let _enter = self.span.enter(); let mut hidden_states = hidden_states.unsqueeze(1)?; + if let Some(pooler) = self.pooler.as_ref() { hidden_states = pooler.forward(&hidden_states)?; hidden_states = hidden_states.tanh()?; @@ -329,6 +328,7 @@ impl GTEClassificationHead { let hidden_states = self.classifier.forward(&hidden_states)?; let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) } } diff --git a/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap index b45a2643..1acb12d4 100644 --- a/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap +++ b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap @@ -1,6 +1,5 @@ --- source: 
backends/candle/tests/test_flash_gte.rs -assertion_line: 83 expression: predictions_single --- -- - 0.050048828 +- - -0.7426758 diff --git a/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap b/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap new file mode 100644 index 00000000..b46f4646 --- /dev/null +++ b/backends/candle/tests/snapshots/test_gte__gte_classification_single.snap @@ -0,0 +1,5 @@ +--- +source: backends/candle/tests/test_gte.rs +expression: predictions_single +--- +- - -0.74173266 diff --git a/backends/candle/tests/test_gte.rs b/backends/candle/tests/test_gte.rs index b8fe3892..6a9e3e3b 100644 --- a/backends/candle/tests/test_gte.rs +++ b/backends/candle/tests/test_gte.rs @@ -1,8 +1,8 @@ mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -137,3 +137,32 @@ fn test_snowflake_gte() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +fn test_gte_classification() -> Result<()> { + let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; + + let input_single = batch( + vec![tokenizer + .encode(("What is Deep Learning?", "Deep Learning is not..."), true) + .unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec<Vec<f32>> = backend + .predict(input_single)? 
+ .into_iter() + .map(|(_, v)| v) + .collect(); + let predictions_single = SnapshotScores::from(predictions); + + let matcher = relative_matcher(); + insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher); + + Ok(()) +}