Skip to content

Commit

Permalink
Merge pull request StarlightSearch#74 from akshayballal95/main
Browse files Browse the repository at this point in the history
Fix cloud embeddings

Former-commit-id: 169f079803132f6e7e79e4b5f91cf03a6e2c5d7f [formerly 3c7e62fb510c5b109d5a7d5d122a396032db08ef] [formerly 25afbc3 [formerly a1877ae]]
Former-commit-id: 25afbc3
Former-commit-id: d34d4ba697f5ec3eec8816dccb5b1f4dc48d7a61
  • Loading branch information
akshayballal95 authored Sep 11, 2024
2 parents fe1c028 + c0632ee commit 7222ae9
Show file tree
Hide file tree
Showing 25 changed files with 299 additions and 214 deletions.
40 changes: 40 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions examples/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel
from embed_anything.vectordb import Adapter
from pinecone import Pinecone, ServerlessSpec
import os
from time import time


model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
# model = EmbeddingModel.from_pretrained_hf(
# WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
# )

model = EmbeddingModel.from_pretrained_cloud(
WhichModel.OpenAI, model_id="text-embedding-3-small"
)
config = TextEmbedConfig(chunk_size=512, batch_size=32)

Expand Down
1 change: 1 addition & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ tokio = { version = "1.39.0", features = ["rt-multi-thread"]}
extension-module = ["pyo3/extension-module"]
mkl = ["embed_anything/mkl"]
accelerate = ["embed_anything/accelerate"]
cuda = ["embed_anything/cuda"]
99 changes: 50 additions & 49 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use embed_anything::{
self, config::TextEmbedConfig, emb_audio, embeddings::embed::Embeder,
file_processor::audio::audio_processor, text_loader::FileLoadingError,
};
use pyo3::{exceptions::PyValueError, exceptions::PyFileNotFoundError, prelude::*};
use pyo3::{exceptions::PyFileNotFoundError, exceptions::PyValueError, prelude::*};
use std::{
collections::HashMap,
path::{Path, PathBuf},
Expand Down Expand Up @@ -218,16 +218,20 @@ pub fn embed_query(
) -> PyResult<Vec<EmbedData>> {
let config = config.map(|c| &c.inner);
let embedding_model = &embeder.inner;

Ok(embed_anything::embed_query(
query,
embedding_model,
Some(config.unwrap_or(&TextEmbedConfig::default())),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.into_iter()
.map(|data| EmbedData { inner: data })
.collect())
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
Ok(rt.block_on(async {
embed_anything::embed_query(
query,
embedding_model,
Some(config.unwrap_or(&TextEmbedConfig::default())),
)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.into_iter()
.map(|data| EmbedData { inner: data })
.collect()
}))
}

#[pyfunction]
Expand All @@ -240,6 +244,7 @@ pub fn embed_file(
) -> PyResult<Option<Vec<EmbedData>>> {
let config = config.map(|c| &c.inner);
let embedding_model = &embeder.inner;
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
if !Path::new(file_name).exists() {
// check if the file exists other wise return a "File not found" error with PyValueError
return Err(PyFileNotFoundError::new_err(format!(
Expand Down Expand Up @@ -267,27 +272,15 @@ pub fn embed_file(
None => None,
};

let data = embed_anything::embed_file(file_name, &embedding_model, config, adapter)
.map_err(|e| {
if let Some(file_loading_error) = e.downcast_ref::<FileLoadingError>() {
match file_loading_error {
FileLoadingError::FileNotFound(file) => {
PyFileNotFoundError::new_err(file.clone())
}
FileLoadingError::UnsupportedFileType(file) => {
PyValueError::new_err(file.clone())
}
}
} else {
PyValueError::new_err(e.to_string())
}
})?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
Ok(data)
let embeddings = rt.block_on(async {
embed_anything::embed_file(file_name, &embedding_model, config, adapter).await
}).map_err(|e| match e.downcast_ref::<FileLoadingError>() {
Some(FileLoadingError::FileNotFound(file)) => PyFileNotFoundError::new_err(file.clone()),
Some(FileLoadingError::UnsupportedFileType(file)) => PyValueError::new_err(file.clone()),
None => PyValueError::new_err(e.to_string()),
})?;

Ok(embeddings.map(|embs| embs.into_iter().map(|data| EmbedData { inner: data }).collect()))
}

#[pyfunction]
Expand All @@ -301,14 +294,18 @@ pub fn embed_audio_file(
let config = text_embed_config.map(|c| &c.inner);
let embedding_model = &embeder.inner;
let audio_decoder = &mut audio_decoder.inner;

let data = emb_audio(audio_file, audio_decoder, embedding_model, config)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
let data = rt.block_on(async {
emb_audio(audio_file, audio_decoder, embedding_model, config)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
})
});
Ok(data)
}

Expand Down Expand Up @@ -339,7 +336,7 @@ pub fn embed_directory(
.call1(py, (converted_data,))
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap();
})
});
};
Some(callback)
}
Expand Down Expand Up @@ -422,7 +419,7 @@ pub fn embed_webpage(
) -> PyResult<Option<Vec<EmbedData>>> {
let embedding_model = &embeder.inner;
let config = config.map(|c| &c.inner);

let rt = Builder::new_multi_thread().enable_all().build().unwrap();
let adapter = match adapter {
Some(adapter) => {
let callback = move |data: Vec<embed_anything::embeddings::embed::EmbedData>| {
Expand All @@ -443,13 +440,17 @@ pub fn embed_webpage(
None => None,
};

let data = embed_anything::embed_webpage(url, embedding_model, config, adapter)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
let data = rt.block_on(async {
embed_anything::embed_webpage(url, embedding_model, config, adapter)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
})
});
Ok(data)
}

Expand Down
2 changes: 2 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ itertools = "0.13.0"
symphonia = { version = "0.5.3", features = ["all"]}
byteorder = "1.5.0"

futures = "0.3.30"

# Optional Dependency
intel-mkl-src = { version = "0.8.1", optional = true }
accelerate-src = { version = "0.3.2", optional = true }
Expand Down
5 changes: 4 additions & 1 deletion rust/examples/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use embed_anything::{
file_processor::audio::audio_processor::AudioDecoderModel,
};

fn main() {
#[tokio::main]
async fn main() {
let audio_path = std::path::PathBuf::from("test_files/audio/samples_hp0.wav");
let mut audio_decoder = AudioDecoderModel::from_pretrained(
Some("openai/whisper-tiny.en"),
Expand All @@ -24,6 +25,8 @@ fn main() {
&bert_model,
Some(&text_embed_config),
)
.await
.unwrap()
.unwrap();

println!("{:?}", embeddings);
Expand Down
11 changes: 11 additions & 0 deletions rust/examples/chunkers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use embed_anything::text_loader::TextLoader;
use embed_anything::chunkers::statistical::StatisticalChunker;
/// Example: extract text from a PDF and split it with the statistical chunker.
fn main() {
    // Use a repo-relative path, consistent with the other examples in this
    // repository (e.g. cloud.rs), instead of a machine-specific absolute path.
    let text = TextLoader::extract_text("test_files/attention.pdf").unwrap();

    let chunker = StatisticalChunker {
        // `verbose: true` makes the chunker print each chunk as it is produced,
        // so this example's output comes from the chunker itself.
        verbose: true,
        ..Default::default()
    };

    // Prefix with `_` to silence the unused-variable warning: the chunks are
    // already printed via `verbose`, so the return value is intentionally unused.
    let _chunks = chunker._chunk(&text, 32);
}
2 changes: 1 addition & 1 deletion rust/examples/clip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() {
.unwrap()
.unwrap();

let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], &model, None).unwrap();
let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], &model, None).await.unwrap();
let n_vectors = out.len();
let out_embeddings = Tensor::from_vec(
out.iter()
Expand Down
18 changes: 15 additions & 3 deletions rust/examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{path::PathBuf, sync::Arc};

use embed_anything::{
config::TextEmbedConfig,
embed_directory_stream, embed_file,
embed_directory_stream, embed_file, embed_query,
embeddings::{
cloud::cohere::CohereEmbeder,
embed::{EmbedData, Embeder},
Expand All @@ -29,12 +29,24 @@ async fn main() -> Result<()> {
.await?
.unwrap();

let _cohere_embedding: Option<Vec<EmbedData>> = embed_file(

let _file_embedding = embed_file(
"test_files/attention.pdf",
&openai_model,
Some(&text_embed_config),
None::<fn(Vec<EmbedData>)>,
).await
?
.unwrap();

let _cohere_embedding = embed_file(
"test_files/attention.pdf",
&cohere_model,
Some(&text_embed_config),
None::<fn(Vec<EmbedData>)>,
)?;
)
.await?
.unwrap();

Ok(())
}
9 changes: 5 additions & 4 deletions rust/examples/web_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use embed_anything::{
embeddings::embed::{EmbedData, Embeder},
};

fn main() {
#[tokio::main]
async fn main() {
let start_time = std::time::Instant::now();
let url = "https://www.scrapingbee.com/blog/web-scraping-rust/".to_string();

Expand All @@ -19,9 +20,8 @@ fn main() {
&embeder,
Some(&embed_config),
None::<fn(Vec<EmbedData>)>,
)
.unwrap()
.unwrap();
).await
.unwrap().unwrap();
let embeddings: Vec<Vec<f32>> = embed_data
.iter()
.map(|data| data.embedding.clone())
Expand All @@ -37,6 +37,7 @@ fn main() {

let query = vec!["Rust for web scraping".to_string()];
let query_embedding: Vec<f32> = embed_query(query, &embeder, Some(&embed_config))
.await
.unwrap()
.iter()
.map(|data| data.embedding.clone())
Expand Down
4 changes: 3 additions & 1 deletion rust/src/chunkers/cumulative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
}
}

pub fn _chunk(&self, text: &str) {
pub async fn _chunk(&self, text: &str) {
let splits = self
.splitter
.chunks(text)
Expand Down Expand Up @@ -70,13 +70,15 @@ impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
let curr_chunk_docs_embed = self
.encoder
.embed(&[curr_chunk_docs.to_string()], Some(32))

.unwrap()
.into_iter()
.flatten()
.collect::<Vec<_>>();
let next_doc_embed = self
.encoder
.embed(&[next_doc.to_string()], Some(32))

.unwrap()
.into_iter()
.flatten()
Expand Down
Loading

0 comments on commit 7222ae9

Please sign in to comment.