From ffa67ae06ed83e5e1b9828b69413e7be2858a7a9 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 5 Apr 2025 14:25:49 +0200 Subject: [PATCH 1/2] Enable ModernBert on metal --- backends/candle/src/lib.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 4d66a4f6..c841aaf4 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -278,11 +278,6 @@ impl CandleBackend { Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?)) } (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => match device { - Device::Metal(_) => { - return Err(BackendError::Start( - "ModernBert is not currently supported on MPS device".to_string(), - )); - } _ => { tracing::info!("Starting ModernBert model on {:?}", device); Ok(Box::new( From 14bc99c386c8a99e5d3cc1b468a5d2e1788c5a6a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 5 Apr 2025 17:35:12 +0200 Subject: [PATCH 2/2] Remove redundant match --- backends/candle/src/lib.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index c841aaf4..a9ddea97 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -277,14 +277,12 @@ impl CandleBackend { tracing::info!("Starting MPNet model on {:?}", device); Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?)) } - (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => match device { - _ => { - tracing::info!("Starting ModernBert model on {:?}", device); - Ok(Box::new( - ModernBertModel::load(vb, &config, model_type).s()?, - )) - } - }, + (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting ModernBert model on {:?}", device); + Ok(Box::new( + ModernBertModel::load(vb, &config, model_type).s()?, + )) + } #[cfg(feature = "cuda")] (Config::Bert(config), Device::Cuda(_)) => { if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))