Skip to content

Commit

Permalink
fix modernbert embed
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Jan 30, 2025
1 parent ab61898 commit 1e3ebaa
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 63 deletions.
4 changes: 2 additions & 2 deletions rust/examples/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use std::{path::PathBuf, time::Instant};
#[tokio::main]
async fn main() {
let model = Arc::new(EmbedderBuilder::new()
.model_architecture("jina")
.model_id(Some("jinaai/jina-embeddings-v2-small-en"))
.model_architecture("modernbert")
.model_id(Some("nomic-ai/modernbert-embed-base"))
.revision(None)
.token(None)
.from_pretrained_hf()
Expand Down
16 changes: 6 additions & 10 deletions rust/examples/ort_models.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use candle_core::{Device, Tensor};
use embed_anything::config::TextEmbedConfig;
use embed_anything::embeddings::embed::{EmbedData, Embedder};
use embed_anything::embeddings::embed::{EmbedData, Embedder, EmbedderBuilder};
use embed_anything::embeddings::local::text_embedding::ONNXModel;
use embed_anything::text_loader::SplittingStrategy;
use embed_anything::Dtype;
Expand All @@ -12,15 +12,11 @@ use std::time::Instant;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let model = Arc::new(
Embedder::from_pretrained_onnx(
"jina",
Some(ONNXModel::JINAV3),
None,
None,
Some(Dtype::F16),
None,
)
.unwrap(),
EmbedderBuilder::new()
.model_architecture("bert")
.onnx_model_id(Some(ONNXModel::ModernBERTBase))
.from_pretrained_onnx()
.unwrap()
);

let config = TextEmbedConfig::default()
Expand Down
56 changes: 5 additions & 51 deletions rust/src/models/modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,30 +237,6 @@ impl Module for ModernBertHead {
}
}

#[derive(Clone)]
pub struct ModernBertDecoder {
decoder: Linear,
}

impl ModernBertDecoder {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
// The decoder weights are tied with the embeddings layer weights
let decoder_weights = vb.get(
(config.vocab_size, config.hidden_size),
"model.embeddings.tok_embeddings.weight",
)?;
let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
let decoder = Linear::new(decoder_weights, Some(decoder_bias));
Ok(Self { decoder })
}
}

impl Module for ModernBertDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.decoder)?;
Ok(xs)
}
}

// Global attention mask calculated from padded token inputs
fn prepare_4d_attention_mask(
Expand Down Expand Up @@ -310,7 +286,6 @@ pub struct ModernBert {
norm: LayerNorm,
layers: Vec<ModernBertLayer>,
final_norm: LayerNorm,
head: ModernBertHead,
local_attention_size: usize,
}

Expand All @@ -319,12 +294,12 @@ impl ModernBert {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("model.embeddings.tok_embeddings"),
vb.pp("embeddings.tok_embeddings"),
)?;
let norm = layer_norm_no_bias(
config.hidden_size,
config.layer_norm_eps,
vb.pp("model.embeddings.norm"),
vb.pp("embeddings.norm"),
)?;
let global_rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
Expand All @@ -343,7 +318,7 @@ impl ModernBert {
for layer_id in 0..config.num_hidden_layers {
let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
layers.push(ModernBertLayer::load(
vb.pp(format!("model.layers.{layer_id}")),
vb.pp(format!("layers.{layer_id}")),
config,
if layer_uses_local_attention {
local_rotary_emb.clone()
Expand All @@ -357,16 +332,14 @@ impl ModernBert {
let final_norm = layer_norm_no_bias(
config.hidden_size,
config.layer_norm_eps,
vb.pp("model.final_norm"),
vb.pp("final_norm"),
)?;
let head = ModernBertHead::load(vb.pp("head"), config)?;

Ok(Self {
word_embeddings,
norm,
layers,
final_norm,
head,
local_attention_size: config.local_attention,
})
}
Expand All @@ -381,27 +354,8 @@ impl ModernBert {
for layer in self.layers.iter() {
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
}
let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
let xs = xs.apply(&self.final_norm)?;
Ok(xs)
}
}

// ModernBERT for the fill-mask task
#[derive(Clone)]
pub struct ModernBertForMaskedLM {
model: ModernBert,
decoder: ModernBertDecoder,
}

impl ModernBertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let model = ModernBert::load(vb.clone(), config)?;
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
Ok(Self { model, decoder })
}

pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
Ok(xs)
}
}

0 comments on commit 1e3ebaa

Please sign in to comment.