diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index 4e6cea44..d75841f1 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -31,7 +31,7 @@ impl OrtBackend { let pool = match model_type { ModelType::Classifier => Pool::Cls, ModelType::Embedding(pool) => match pool { - Pool::Splade | Pool::LastToken => { + Pool::Splade => { return Err(BackendError::Start(format!( "Pooling {pool} is not supported for this backend. Use `candle` backend instead." ))); @@ -204,8 +204,10 @@ impl Backend for OrtBackend { let pooled_embeddings = match self.pool { // CLS pooling Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(), - // Last token pooling is not supported for this model - Pool::LastToken => unreachable!(), + Pool::LastToken => { + let axis_len = outputs.len_of(Axis(1)); + outputs.slice(s![.., axis_len - 1, ..]).into_owned().into_dyn() + }, // Mean pooling Pool::Mean => { if masking {