Skip to content

Commit

Permalink
fix clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Feb 7, 2025
1 parent 31801cf commit a881fca
Show file tree
Hide file tree
Showing 24 changed files with 155 additions and 134 deletions.
3 changes: 2 additions & 1 deletion python/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct TextEmbedConfig {
pub inner: embed_anything::config::TextEmbedConfig,
}

#[allow(clippy::too_many_arguments)]
#[pymethods]
impl TextEmbedConfig {
#[new]
Expand Down Expand Up @@ -44,7 +45,7 @@ impl TextEmbedConfig {
.with_buffer_size(buffer_size.unwrap_or(100))
.with_splitting_strategy(strategy.unwrap_or(SplittingStrategy::Sentence))
.with_semantic_encoder(semantic_encoder)
.with_ocr(use_ocr.unwrap_or(false), tesseract_path)
.with_ocr(use_ocr.unwrap_or(false), tesseract_path),
}
}

Expand Down
14 changes: 7 additions & 7 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ impl EmbeddingModel {
let model_id = model_id.unwrap_or("jinaai/jina-embeddings-v2-small-en");
let model = Embedder::Text(TextEmbedder::Jina(Box::new(
embed_anything::embeddings::local::jina::JinaEmbedder::new(
model_id,
revision,
token,
model_id, revision, token,
)
.unwrap(),
)));
Expand Down Expand Up @@ -286,10 +284,12 @@ impl EmbeddingModel {
Some(Dtype::F32) => Some(embed_anything::Dtype::F32),
None => None,
};
let model_name = model_name.map(|model_name| embed_anything::embeddings::local::text_embedding::ONNXModel::from_str(
&model_name.to_string(),
)
.unwrap());
let model_name = model_name.map(|model_name| {
embed_anything::embeddings::local::text_embedding::ONNXModel::from_str(
&model_name.to_string(),
)
.unwrap()
});
match model {
WhichModel::Bert => {
let model = Embedder::Text(TextEmbedder::Bert(Box::new(
Expand Down
2 changes: 1 addition & 1 deletion rust/examples/audio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use embed_anything::{
config::TextEmbedConfig, emb_audio, embeddings::embed:: EmbedderBuilder,
config::TextEmbedConfig, emb_audio, embeddings::embed::EmbedderBuilder,
file_processor::audio::audio_processor::AudioDecoderModel, text_loader::SplittingStrategy,
};

Expand Down
20 changes: 11 additions & 9 deletions rust/examples/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ use std::{path::PathBuf, time::Instant};

#[tokio::main]
async fn main() {
let model = Arc::new(EmbedderBuilder::new()
.model_architecture("modernbert")
.model_id(Some("nomic-ai/modernbert-embed-base"))
.revision(None)
.token(None)
.dtype(Some(Dtype::F16))
.from_pretrained_hf()
.unwrap());

let model = Arc::new(
EmbedderBuilder::new()
.model_architecture("modernbert")
.model_id(Some("nomic-ai/modernbert-embed-base"))
.revision(None)
.token(None)
.dtype(Some(Dtype::F16))
.from_pretrained_hf()
.unwrap(),
);

let config = TextEmbedConfig::default()
.with_chunk_size(256, Some(0.3))
.with_batch_size(32)
Expand Down
8 changes: 4 additions & 4 deletions rust/examples/ort_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use std::time::Instant;
async fn main() -> Result<(), anyhow::Error> {
let model = Arc::new(
EmbedderBuilder::new()
.model_architecture("bert")
.onnx_model_id(Some(ONNXModel::ModernBERTBase))
.from_pretrained_onnx()
.unwrap()
.model_architecture("bert")
.onnx_model_id(Some(ONNXModel::ModernBERTBase))
.from_pretrained_onnx()
.unwrap(),
);

let config = TextEmbedConfig::default()
Expand Down
1 change: 0 additions & 1 deletion rust/examples/reranker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#[cfg(feature = "ort")]
fn main() {
use embed_anything::Dtype;
Expand Down
2 changes: 1 addition & 1 deletion rust/examples/web_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use candle_core::Tensor;
use embed_anything::{
config::TextEmbedConfig,
embed_query, embed_webpage,
embeddings::embed::{EmbedData, EmbedderBuilder},
embeddings::embed::{EmbedData, EmbedderBuilder},
text_loader::SplittingStrategy,
};

Expand Down
7 changes: 4 additions & 3 deletions rust/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl Default for TextEmbedConfig {
}
}

#[allow(clippy::too_many_arguments)]
impl TextEmbedConfig {
pub fn new(
chunk_size: Option<usize>,
Expand Down Expand Up @@ -128,9 +129,9 @@ impl TextEmbedConfig {
self
}

/// Use this to do OCR on the documents to extract text.
/// Set the path to None if you want to use the default path with tesseract installed on your system.
/// You can check if tesseract is installed by running tesseract in your command line.
/// Use this to do OCR on the documents to extract text.
/// Set the path to None if you want to use the default path with tesseract installed on your system.
/// You can check if tesseract is installed by running tesseract in your command line.
/// If you want to use a custom path, you can set the path to the path of the tesseract executable.
pub fn with_ocr(mut self, use_ocr: bool, tesseract_path: Option<&str>) -> Self {
self.use_ocr = Some(use_ocr);
Expand Down
63 changes: 39 additions & 24 deletions rust/src/embeddings/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,30 @@ impl TextEmbedder {
dtype: Option<Dtype>,
) -> Result<Self, anyhow::Error> {
match model {
"jina" | "Jina" => Ok(Self::Jina(Box::new(JinaEmbedder::new(model_id, revision, token)?))),
"jina" | "Jina" => Ok(Self::Jina(Box::new(JinaEmbedder::new(
model_id, revision, token,
)?))),

"Bert" | "bert" => Ok(Self::Bert(Box::new(BertEmbedder::new(
model_id.to_string(),
revision.map(|s| s.to_string()),
token,
)?))),
"sparse-bert" | "SparseBert" | "SPARSE-BERT" => Ok(Self::Bert(Box::new(
SparseBertEmbedder::new(model_id.to_string(), revision.map(|s| s.to_string()), token)?,
))),
"modernbert" | "ModernBert" | "MODERNBERT" => Ok(Self::ModernBert(Box::new(
ModernBertEmbedder::new(model_id.to_string(), revision.map(|s| s.to_string()), token, dtype)?,
))),
"sparse-bert" | "SparseBert" | "SPARSE-BERT" => {
Ok(Self::Bert(Box::new(SparseBertEmbedder::new(
model_id.to_string(),
revision.map(|s| s.to_string()),
token,
)?)))
}
"modernbert" | "ModernBert" | "MODERNBERT" => {
Ok(Self::ModernBert(Box::new(ModernBertEmbedder::new(
model_id.to_string(),
revision.map(|s| s.to_string()),
token,
dtype,
)?)))
}
_ => Err(anyhow::anyhow!("Model not supported")),
}
}
Expand Down Expand Up @@ -313,7 +324,7 @@ impl VisionEmbedder {
/// .from_pretrained_hf()
/// .unwrap();
/// ```
///
///
/// ### Cloud Embedding Model
/// ```rust
/// use embed_anything::embeddings::embed::EmbedderBuilder;
Expand All @@ -324,7 +335,7 @@ impl VisionEmbedder {
/// .from_pretrained_cloud()
/// .unwrap();
/// ```
///
///
/// ### ONNX Embedding Model
/// ```rust,ignore
/// use embed_anything::embeddings::embed::EmbedderBuilder;
Expand All @@ -345,7 +356,7 @@ pub struct EmbedderBuilder {
// Either HF Model ID or the Cloud Model that youu want to use
model_id: Option<String>,
revision: Option<String>,
// The Hugging Face token
// The Hugging Face token
token: Option<String>,
// The API key for the cloud model
api_key: Option<String>,
Expand Down Expand Up @@ -509,20 +520,24 @@ impl Embedder {
token,
dtype,
)?)),
"sparse-bert" | "SparseBert" | "SPARSE-BERT" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(
model_architecture,
model_id,
revision,
token,
dtype,
)?)),
"modernbert" | "ModernBert" | "MODERNBERT" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(
model_architecture,
model_id,
revision,
token,
dtype,
)?)),
"sparse-bert" | "SparseBert" | "SPARSE-BERT" => {
Ok(Self::Text(TextEmbedder::from_pretrained_hf(
model_architecture,
model_id,
revision,
token,
dtype,
)?))
}
"modernbert" | "ModernBert" | "MODERNBERT" => {
Ok(Self::Text(TextEmbedder::from_pretrained_hf(
model_architecture,
model_id,
revision,
token,
dtype,
)?))
}
_ => Err(anyhow::anyhow!("Model not supported")),
}
}
Expand Down
6 changes: 1 addition & 5 deletions rust/src/embeddings/local/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ impl Default for BertEmbedder {
}
}
impl BertEmbedder {
pub fn new(
model_id: String,
revision: Option<String>,
token: Option<&str>,
) -> Result<Self, E> {
pub fn new(model_id: String, revision: Option<String>, token: Option<&str>) -> Result<Self, E> {
let model_info = get_model_info_by_hf_id(&model_id);
let pooling = match model_info {
Some(info) => info
Expand Down
6 changes: 3 additions & 3 deletions rust/src/embeddings/local/colbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ use rayon::{iter::ParallelIterator, slice::ParallelSlice};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

use crate::embeddings::{
embed::EmbeddingResult,
utils::{get_attention_mask_ndarray, tokenize_batch_ndarray},
};
embed::EmbeddingResult,
utils::{get_attention_mask_ndarray, tokenize_batch_ndarray},
};

use super::bert::{BertEmbed, TokenizerConfig};

Expand Down
10 changes: 5 additions & 5 deletions rust/src/embeddings/local/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ pub mod colpali;
pub mod colpali_ort;
pub mod jina;
pub mod model_info;
pub mod pooling;
pub mod text_embedding;
#[cfg(feature = "ort")]
pub mod ort_jina;
pub mod modernbert;
#[cfg(feature = "ort")]
pub mod ort_bert;
pub mod modernbert;
#[cfg(feature = "ort")]
pub mod ort_jina;
pub mod pooling;
pub mod text_embedding;
39 changes: 26 additions & 13 deletions rust/src/embeddings/local/modernbert.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
use crate::{
embeddings::{normalize_l2, utils::{get_attention_mask, tokenize_batch}},
models::modernbert::{Config, ModernBert}, Dtype,
embeddings::{
normalize_l2,
utils::{get_attention_mask, tokenize_batch},
},
models::modernbert::{Config, ModernBert},
Dtype,
};
use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

use crate::{
embeddings::{embed::EmbeddingResult, select_device},
models::bert::DTYPE,
};
use crate::embeddings::{embed::EmbeddingResult, select_device};

use super::{bert::BertEmbed, pooling::{ModelOutput, Pooling}};
use super::{
bert::BertEmbed,
pooling::{ModelOutput, Pooling},
};
pub struct ModernBertEmbedder {
pub model: ModernBert,
pub tokenizer: Tokenizer,
Expand All @@ -27,7 +31,12 @@ impl Default for ModernBertEmbedder {
}
}
impl ModernBertEmbedder {
pub fn new(model_id: String, revision: Option<String>, token: Option<&str>, dtype: Option<Dtype>) -> Result<Self, E> {
pub fn new(
model_id: String,
revision: Option<String>,
token: Option<&str>,
dtype: Option<Dtype>,
) -> Result<Self, E> {
let (config_filename, tokenizer_filename, weights_filename) = {
let api = ApiBuilder::new()
.with_token(token.map(|s| s.to_string()))
Expand Down Expand Up @@ -83,7 +92,6 @@ impl ModernBertEmbedder {
_ => DType::F32,
};
let vb = if weights_filename.ends_with("model.safetensors") {

unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device)? }
} else {
println!("Can't find model.safetensors, loading from pytorch_model.bin");
Expand Down Expand Up @@ -112,10 +120,15 @@ impl BertEmbed for ModernBertEmbedder {
let mut encodings: Vec<EmbeddingResult> = Vec::new();

for mini_text_batch in text_batch.chunks(batch_size) {
let token_ids =
tokenize_batch(&self.tokenizer, mini_text_batch, &self.device).unwrap();
let attention_mask = get_attention_mask(&self.tokenizer, mini_text_batch, &self.device).unwrap();
let embeddings: Tensor = self.model.forward(&token_ids, &attention_mask).unwrap().to_dtype(DType::F32).unwrap();
let token_ids = tokenize_batch(&self.tokenizer, mini_text_batch, &self.device).unwrap();
let attention_mask =
get_attention_mask(&self.tokenizer, mini_text_batch, &self.device).unwrap();
let embeddings: Tensor = self
.model
.forward(&token_ids, &attention_mask)
.unwrap()
.to_dtype(DType::F32)
.unwrap();
let pooled_output = self
.pooling
.pool(&ModelOutput::Tensor(embeddings.clone()))
Expand Down
4 changes: 2 additions & 2 deletions rust/src/embeddings/local/ort_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use super::bert::{BertEmbed, TokenizerConfig};
use super::pooling::{ModelOutput, Pooling};
use super::text_embedding::ONNXModel;
use crate::embeddings::embed::EmbeddingResult;
use crate::embeddings::local::text_embedding::models_map;
use crate::embeddings::utils::{
get_attention_mask_ndarray, get_type_ids_ndarray, tokenize_batch_ndarray,
};
use crate::embeddings::local::text_embedding::models_map;

use crate::Dtype;
use anyhow::Error as E;
use hf_hub::api::sync::Api;
use hf_hub::Repo;
use ndarray::prelude::*;
Expand All @@ -17,7 +18,6 @@ use ort::session::Session;
use ort::value::Value;
use rayon::prelude::*;
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use anyhow::Error as E;

#[derive(Debug)]
pub struct OrtBertEmbedder {
Expand Down
Loading

0 comments on commit a881fca

Please sign in to comment.