Skip to content

Commit

Permalink
Merge pull request #118 from boswelja/semantic-splitting-require-encoder
Browse files Browse the repository at this point in the history
Make SplittingStrategy::Semantic require an encoder
  • Loading branch information
akshayballal95 authored Feb 11, 2025
2 parents a881fca + cf62226 commit b748078
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 88 deletions.
28 changes: 14 additions & 14 deletions python/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use std::sync::Arc;

use embed_anything::text_loader::SplittingStrategy;
use pyo3::prelude::*;

use embed_anything::config::SplittingStrategy;
use crate::EmbeddingModel;

#[pyclass]
Expand All @@ -28,23 +25,26 @@ impl TextEmbedConfig {
) -> Self {
let strategy = match splitting_strategy {
Some(strategy) => match strategy {
"sentence" => Some(SplittingStrategy::Sentence),
"semantic" => Some(SplittingStrategy::Semantic),
_ => None,
"sentence" => SplittingStrategy::Sentence,
"semantic" => {
if semantic_encoder.is_none() {
panic!("Semantic encoder is required when using Semantic splitting strategy");
}
SplittingStrategy::Semantic {
semantic_encoder: semantic_encoder.unwrap().inner.clone()
}
},
_ => panic!("Unknown strategy provided!"),
},
None => None,
None => SplittingStrategy::Sentence,
};
let semantic_encoder = semantic_encoder.map(|model| Arc::clone(&model.inner));
if matches!(strategy, Some(SplittingStrategy::Semantic)) && semantic_encoder.is_none() {
panic!("Semantic encoder is required when using Semantic splitting strategy");
}

Self {
inner: embed_anything::config::TextEmbedConfig::default()
.with_chunk_size(chunk_size.unwrap_or(256), overlap_ratio)
.with_batch_size(batch_size.unwrap_or(32))
.with_buffer_size(buffer_size.unwrap_or(100))
.with_splitting_strategy(strategy.unwrap_or(SplittingStrategy::Sentence))
.with_semantic_encoder(semantic_encoder)
.with_splitting_strategy(strategy)
.with_ocr(use_ocr.unwrap_or(false), tesseract_path),
}
}
Expand Down
69 changes: 27 additions & 42 deletions rust/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
use crate::embeddings::embed::Embedder;
use std::sync::Arc;

use crate::{embeddings::embed::Embedder, text_loader::SplittingStrategy};

/// Configuration for text embedding.
///
/// # Example: Creating a new instance
///
/// ```rust
/// use embed_anything::config::TextEmbedConfig;
/// use embed_anything::text_loader::SplittingStrategy;
/// use embed_anything::config::{TextEmbedConfig, SplittingStrategy};
/// let config = TextEmbedConfig::new(
/// Some(512),
/// Some(128),
/// Some(100),
/// Some(0.0),
/// Some(SplittingStrategy::Sentence),
/// None,
/// Some(true)
/// SplittingStrategy::Sentence,
/// Some(true),
/// None
/// );
/// ```
///
/// # Example: Overriding a single default
///
/// ```rust
/// use embed_anything::config::TextEmbedConfig;
/// use embed_anything::text_loader::SplittingStrategy;
/// use embed_anything::config::{TextEmbedConfig, SplittingStrategy};
/// let config = TextEmbedConfig {
/// splitting_strategy: Some(SplittingStrategy::Semantic),
/// // `SplittingStrategy::Semantic { semantic_encoder }` additionally requires an embedder.
/// splitting_strategy: SplittingStrategy::Sentence,
/// ..Default::default()
/// };
/// ```
Expand All @@ -45,10 +42,7 @@ pub struct TextEmbedConfig {
pub buffer_size: Option<usize>,
/// Controls how documents are split into segments. See [SplittingStrategy] for options.
/// Defaults to [SplittingStrategy::Sentence]
pub splitting_strategy: Option<SplittingStrategy>,
/// Allows overriding the embedder used when the splitting strategy is
/// [SplittingStrategy::Semantic]. Defaults to JINA.
pub semantic_encoder: Option<Arc<Embedder>>,
pub splitting_strategy: SplittingStrategy,
/// When embedding a PDF, controls whether **o**ptical **c**haracter **r**ecognition is used on
/// the PDF to extract text. This process involves rendering the PDF as a series of images, and
/// extracting text from the images. Defaults to false.
Expand All @@ -63,8 +57,7 @@ impl Default for TextEmbedConfig {
overlap_ratio: Some(0.0),
batch_size: Some(32),
buffer_size: Some(100),
splitting_strategy: None,
semantic_encoder: None,
splitting_strategy: SplittingStrategy::Sentence,
use_ocr: None,
tesseract_path: None,
}
Expand All @@ -78,29 +71,17 @@ impl TextEmbedConfig {
batch_size: Option<usize>,
buffer_size: Option<usize>,
overlap_ratio: Option<f32>,
splitting_strategy: Option<SplittingStrategy>,
semantic_encoder: Option<Arc<Embedder>>,
splitting_strategy: SplittingStrategy,
use_ocr: Option<bool>,
tesseract_path: Option<String>,
) -> Self {
let config = Self::default()
Self::default()
.with_chunk_size(chunk_size.unwrap_or(256), overlap_ratio)
.with_batch_size(batch_size.unwrap_or(32))
.with_buffer_size(buffer_size.unwrap_or(100))
.with_ocr(use_ocr.unwrap_or(false), tesseract_path.as_deref());

match splitting_strategy {
Some(SplittingStrategy::Semantic) => {
if semantic_encoder.is_none() {
panic!("Semantic encoder is required when using Semantic splitting strategy");
}
config
.with_semantic_encoder(Some(semantic_encoder.unwrap()))
.with_splitting_strategy(SplittingStrategy::Semantic)
}
Some(strategy) => config.with_splitting_strategy(strategy),
None => config,
}
.with_ocr(use_ocr.unwrap_or(false), tesseract_path.as_deref())
.with_splitting_strategy(splitting_strategy)
.build()
}

pub fn with_chunk_size(mut self, size: usize, overlap_ratio: Option<f32>) -> Self {
Expand All @@ -120,12 +101,7 @@ impl TextEmbedConfig {
}

pub fn with_splitting_strategy(mut self, strategy: SplittingStrategy) -> Self {
self.splitting_strategy = Some(strategy);
self
}

pub fn with_semantic_encoder(mut self, encoder: Option<Arc<Embedder>>) -> Self {
self.semantic_encoder = encoder;
self.splitting_strategy = strategy;
self
}

Expand All @@ -140,13 +116,22 @@ impl TextEmbedConfig {
}

pub fn build(self) -> TextEmbedConfig {
if self.semantic_encoder.is_none() && self.splitting_strategy.is_some() {
panic!("Semantic encoder is required when using Semantic splitting strategy");
}
self
}
}

/// Selects how text is divided into chunks before it is embedded.
///
/// The semantic variant carries its own embedder, so a semantic strategy cannot be constructed
/// without one.
#[derive(Clone)]
pub enum SplittingStrategy {
/// Splits text-based content by sentence, resulting in one embedding per sentence.
Sentence,
/// Uses an embedder to determine the semantic relevance of chunks of text. Produces chunks that
/// may be longer or shorter than a sentence.
Semantic {
/// The embedder used when splitting semantically. It is reference-counted, so cloning the
/// strategy shares the same embedder instead of duplicating it.
semantic_encoder: Arc<Embedder>
},
}

#[derive(Clone)]
pub struct ImageEmbedConfig {
pub buffer_size: Option<usize>, // Required for adapter. Default is 100.
Expand Down
5 changes: 3 additions & 2 deletions rust/src/file_processor/html_processor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::embeddings::embed::{EmbedData, Embedder};
use crate::embeddings::get_text_metadata;
use crate::text_loader::{SplittingStrategy, TextLoader};
use crate::text_loader::TextLoader;
use anyhow::Result;
use scraper::{Html, Selector};
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use url::Url;
use crate::config::SplittingStrategy;

#[derive(Debug)]
pub struct HtmlDocument {
Expand Down Expand Up @@ -87,7 +88,7 @@ impl HtmlDocument {
for content in tag_content {
let textloader = TextLoader::new(chunk_size, overlap_ratio);
let chunks =
match textloader.split_into_chunks(content, SplittingStrategy::Sentence, None) {
match textloader.split_into_chunks(content, SplittingStrategy::Sentence) {
Some(chunks) => chunks,
None => continue,
};
Expand Down
5 changes: 3 additions & 2 deletions rust/src/file_processor/website_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ use crate::{
get_text_metadata,
},
file_processor::html_processor::HtmlProcessor,
text_loader::{SplittingStrategy, TextLoader},
text_loader::TextLoader,
};
use crate::config::SplittingStrategy;

#[derive(Debug)]
pub struct WebPage {
Expand Down Expand Up @@ -94,7 +95,7 @@ impl WebPage {
for content in tag_content {
let textloader = TextLoader::new(chunk_size, overlap_ratio);
let chunks =
match textloader.split_into_chunks(content, SplittingStrategy::Sentence, None) {
match textloader.split_into_chunks(content, SplittingStrategy::Sentence) {
Some(chunks) => chunks,
None => continue,
};
Expand Down
12 changes: 5 additions & 7 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ use file_loader::FileParser;
use file_processor::audio::audio_processor::AudioDecoderModel;
use itertools::Itertools;
use rayon::prelude::*;
use text_loader::{SplittingStrategy, TextLoader};
use text_loader::TextLoader;
use tokio::sync::mpsc; // Add this at the top of your file

#[cfg(feature = "audio")]
use embeddings::embed_audio;
use crate::config::SplittingStrategy;

pub enum Dtype {
F16,
Expand Down Expand Up @@ -328,16 +329,13 @@ where
let chunk_size = config.chunk_size.unwrap_or(256);
let overlap_ratio = config.overlap_ratio.unwrap_or(0.0);
let batch_size = config.batch_size;
let splitting_strategy = config
.splitting_strategy
.unwrap_or(SplittingStrategy::Sentence);
let semantic_encoder = config.semantic_encoder.clone();
let splitting_strategy = config.splitting_strategy.clone();
let use_ocr = config.use_ocr.unwrap_or(false);
let tesseract_path = config.tesseract_path.clone();
let text = TextLoader::extract_text(&file, use_ocr, tesseract_path.as_deref())?;
let textloader = TextLoader::new(chunk_size, overlap_ratio);
let chunks = textloader
.split_into_chunks(&text, splitting_strategy, semantic_encoder)
.split_into_chunks(&text, splitting_strategy)
.unwrap_or_default();

let metadata = TextLoader::get_metadata(file).ok();
Expand Down Expand Up @@ -711,7 +709,7 @@ where
}
};
let chunks = textloader
.split_into_chunks(&text, SplittingStrategy::Sentence, None)
.split_into_chunks(&text, SplittingStrategy::Sentence)
.unwrap_or_else(|| vec![text.clone()])
.into_iter()
.filter(|chunk| !chunk.trim().is_empty())
Expand Down
30 changes: 9 additions & 21 deletions rust/src/text_loader.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,25 @@
use std::{
collections::HashMap,
fmt::{Debug, Display},
fs,
sync::Arc,
fs
,
};

use crate::{
chunkers::statistical::StatisticalChunker,
embeddings::{embed::TextEmbedder, local::jina::JinaEmbedder},
chunkers::statistical::StatisticalChunker
,
file_processor::docx_processor::DocxProcessor,
};
use crate::{
embeddings::embed::Embedder,
file_processor::{markdown_processor::MarkdownProcessor, txt_processor::TxtProcessor},
};
use crate::file_processor::{markdown_processor::MarkdownProcessor, txt_processor::TxtProcessor};
use anyhow::Error;
use chrono::{DateTime, Local};
use text_splitter::{ChunkConfig, TextSplitter};
use tokenizers::Tokenizer;

use super::file_processor::pdf_processor::PdfProcessor;
use crate::config::SplittingStrategy;
use rayon::prelude::*;

#[derive(Clone, Copy)]
pub enum SplittingStrategy {
Sentence,
Semantic,
}

impl Default for TextLoader {
fn default() -> Self {
Self::new(256, 0.0)
Expand Down Expand Up @@ -86,7 +78,6 @@ impl TextLoader {
&self,
text: &str,
splitting_strategy: SplittingStrategy,
semantic_encoder: Option<Arc<Embedder>>,
) -> Option<Vec<String>> {
if text.is_empty() {
return None;
Expand All @@ -104,12 +95,9 @@ impl TextLoader {
.par_bridge()
.map(|chunk| chunk.to_string())
.collect(),
SplittingStrategy::Semantic => {
let embedder = semantic_encoder.unwrap_or(Arc::new(Embedder::Text(
TextEmbedder::Jina(Box::new(JinaEmbedder::default())),
)));
SplittingStrategy::Semantic { semantic_encoder } => {
let chunker = StatisticalChunker {
encoder: embedder,
encoder: semantic_encoder,
..Default::default()
};

Expand Down Expand Up @@ -192,7 +180,7 @@ mod tests {
.replace(" ", " ");

let text_loader = TextLoader::new(256, 0.0);
let chunks = text_loader.split_into_chunks(&text, SplittingStrategy::Sentence, None);
let chunks = text_loader.split_into_chunks(&text, SplittingStrategy::Sentence);

for chunk in chunks.unwrap() {
println!("-----------------------------------");
Expand Down

0 comments on commit b748078

Please sign in to comment.