Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cloud embeddings #74

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions examples/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel
from embed_anything.vectordb import Adapter
from pinecone import Pinecone, ServerlessSpec
import os
from time import time


model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
# model = EmbeddingModel.from_pretrained_hf(
# WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
# )

model = EmbeddingModel.from_pretrained_cloud(
WhichModel.OpenAI, model_id="text-embedding-3-small"
)
config = TextEmbedConfig(chunk_size=512, batch_size=32)

Expand Down
1 change: 1 addition & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ tokio = { version = "1.39.0", features = ["rt-multi-thread"]}
extension-module = ["pyo3/extension-module"]
mkl = ["embed_anything/mkl"]
accelerate = ["embed_anything/accelerate"]
cuda = ["embed_anything/cuda"]
99 changes: 50 additions & 49 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use embed_anything::{
self, config::TextEmbedConfig, emb_audio, embeddings::embed::Embeder,
file_processor::audio::audio_processor, text_loader::FileLoadingError,
};
use pyo3::{exceptions::PyValueError, exceptions::PyFileNotFoundError, prelude::*};
use pyo3::{exceptions::PyFileNotFoundError, exceptions::PyValueError, prelude::*};
use std::{
collections::HashMap,
path::{Path, PathBuf},
Expand Down Expand Up @@ -218,16 +218,20 @@ pub fn embed_query(
) -> PyResult<Vec<EmbedData>> {
let config = config.map(|c| &c.inner);
let embedding_model = &embeder.inner;

Ok(embed_anything::embed_query(
query,
embedding_model,
Some(config.unwrap_or(&TextEmbedConfig::default())),
)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.into_iter()
.map(|data| EmbedData { inner: data })
.collect())
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
Ok(rt.block_on(async {
embed_anything::embed_query(
query,
embedding_model,
Some(config.unwrap_or(&TextEmbedConfig::default())),
)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.into_iter()
.map(|data| EmbedData { inner: data })
.collect()
}))
}

#[pyfunction]
Expand All @@ -240,6 +244,7 @@ pub fn embed_file(
) -> PyResult<Option<Vec<EmbedData>>> {
let config = config.map(|c| &c.inner);
let embedding_model = &embeder.inner;
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
if !Path::new(file_name).exists() {
// check if the file exists other wise return a "File not found" error with PyValueError
return Err(PyFileNotFoundError::new_err(format!(
Expand Down Expand Up @@ -267,27 +272,15 @@ pub fn embed_file(
None => None,
};

let data = embed_anything::embed_file(file_name, &embedding_model, config, adapter)
.map_err(|e| {
if let Some(file_loading_error) = e.downcast_ref::<FileLoadingError>() {
match file_loading_error {
FileLoadingError::FileNotFound(file) => {
PyFileNotFoundError::new_err(file.clone())
}
FileLoadingError::UnsupportedFileType(file) => {
PyValueError::new_err(file.clone())
}
}
} else {
PyValueError::new_err(e.to_string())
}
})?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
Ok(data)
let embeddings = rt.block_on(async {
embed_anything::embed_file(file_name, &embedding_model, config, adapter).await
}).map_err(|e| match e.downcast_ref::<FileLoadingError>() {
Some(FileLoadingError::FileNotFound(file)) => PyFileNotFoundError::new_err(file.clone()),
Some(FileLoadingError::UnsupportedFileType(file)) => PyValueError::new_err(file.clone()),
None => PyValueError::new_err(e.to_string()),
})?;

Ok(embeddings.map(|embs| embs.into_iter().map(|data| EmbedData { inner: data }).collect()))
}

#[pyfunction]
Expand All @@ -301,14 +294,18 @@ pub fn embed_audio_file(
let config = text_embed_config.map(|c| &c.inner);
let embedding_model = &embeder.inner;
let audio_decoder = &mut audio_decoder.inner;

let data = emb_audio(audio_file, audio_decoder, embedding_model, config)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
let data = rt.block_on(async {
emb_audio(audio_file, audio_decoder, embedding_model, config)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
})
});
Ok(data)
}

Expand Down Expand Up @@ -339,7 +336,7 @@ pub fn embed_directory(
.call1(py, (converted_data,))
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap();
})
});
};
Some(callback)
}
Expand Down Expand Up @@ -422,7 +419,7 @@ pub fn embed_webpage(
) -> PyResult<Option<Vec<EmbedData>>> {
let embedding_model = &embeder.inner;
let config = config.map(|c| &c.inner);

let rt = Builder::new_multi_thread().enable_all().build().unwrap();
let adapter = match adapter {
Some(adapter) => {
let callback = move |data: Vec<embed_anything::embeddings::embed::EmbedData>| {
Expand All @@ -443,13 +440,17 @@ pub fn embed_webpage(
None => None,
};

let data = embed_anything::embed_webpage(url, embedding_model, config, adapter)
.map_err(|e| PyValueError::new_err(e.to_string()))?
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
});
let data = rt.block_on(async {
embed_anything::embed_webpage(url, embedding_model, config, adapter)
.await
.map_err(|e| PyValueError::new_err(e.to_string()))
.unwrap()
.map(|data| {
data.into_iter()
.map(|data| EmbedData { inner: data })
.collect::<Vec<_>>()
})
});
Ok(data)
}

Expand Down
2 changes: 2 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ itertools = "0.13.0"
symphonia = { version = "0.5.3", features = ["all"]}
byteorder = "1.5.0"

futures = "0.3.30"

# Optional Dependency
intel-mkl-src = { version = "0.8.1", optional = true }
accelerate-src = { version = "0.3.2", optional = true }
Expand Down
5 changes: 4 additions & 1 deletion rust/examples/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use embed_anything::{
file_processor::audio::audio_processor::AudioDecoderModel,
};

fn main() {
#[tokio::main]
async fn main() {
let audio_path = std::path::PathBuf::from("test_files/audio/samples_hp0.wav");
let mut audio_decoder = AudioDecoderModel::from_pretrained(
Some("openai/whisper-tiny.en"),
Expand All @@ -24,6 +25,8 @@ fn main() {
&bert_model,
Some(&text_embed_config),
)
.await
.unwrap()
.unwrap();

println!("{:?}", embeddings);
Expand Down
11 changes: 11 additions & 0 deletions rust/examples/chunkers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use embed_anything::text_loader::TextLoader;
use embed_anything::chunkers::statistical::StatisticalChunker;
fn main() {
let text = TextLoader::extract_text("/home/akshay/EmbedAnything/test_files/attention.pdf").unwrap();
let chunker = StatisticalChunker{
verbose: true,
..Default::default()
};
let chunks = chunker._chunk(&text, 32);

}
2 changes: 1 addition & 1 deletion rust/examples/clip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() {
.unwrap()
.unwrap();

let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], &model, None).unwrap();
let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], &model, None).await.unwrap();
let n_vectors = out.len();
let out_embeddings = Tensor::from_vec(
out.iter()
Expand Down
18 changes: 15 additions & 3 deletions rust/examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{path::PathBuf, sync::Arc};

use embed_anything::{
config::TextEmbedConfig,
embed_directory_stream, embed_file,
embed_directory_stream, embed_file, embed_query,
embeddings::{
cloud::cohere::CohereEmbeder,
embed::{EmbedData, Embeder},
Expand All @@ -29,12 +29,24 @@ async fn main() -> Result<()> {
.await?
.unwrap();

let _cohere_embedding: Option<Vec<EmbedData>> = embed_file(

let _file_embedding = embed_file(
"test_files/attention.pdf",
&openai_model,
Some(&text_embed_config),
None::<fn(Vec<EmbedData>)>,
).await
?
.unwrap();

let _cohere_embedding = embed_file(
"test_files/attention.pdf",
&cohere_model,
Some(&text_embed_config),
None::<fn(Vec<EmbedData>)>,
)?;
)
.await?
.unwrap();

Ok(())
}
9 changes: 5 additions & 4 deletions rust/examples/web_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use embed_anything::{
embeddings::embed::{EmbedData, Embeder},
};

fn main() {
#[tokio::main]
async fn main() {
let start_time = std::time::Instant::now();
let url = "https://www.scrapingbee.com/blog/web-scraping-rust/".to_string();

Expand All @@ -19,9 +20,8 @@ fn main() {
&embeder,
Some(&embed_config),
None::<fn(Vec<EmbedData>)>,
)
.unwrap()
.unwrap();
).await
.unwrap().unwrap();
let embeddings: Vec<Vec<f32>> = embed_data
.iter()
.map(|data| data.embedding.clone())
Expand All @@ -37,6 +37,7 @@ fn main() {

let query = vec!["Rust for web scraping".to_string()];
let query_embedding: Vec<f32> = embed_query(query, &embeder, Some(&embed_config))
.await
.unwrap()
.iter()
.map(|data| data.embedding.clone())
Expand Down
4 changes: 3 additions & 1 deletion rust/src/chunkers/cumulative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
}
}

pub fn _chunk(&self, text: &str) {
pub async fn _chunk(&self, text: &str) {
let splits = self
.splitter
.chunks(text)
Expand Down Expand Up @@ -70,13 +70,15 @@ impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
let curr_chunk_docs_embed = self
.encoder
.embed(&[curr_chunk_docs.to_string()], Some(32))

.unwrap()
.into_iter()
.flatten()
.collect::<Vec<_>>();
let next_doc_embed = self
.encoder
.embed(&[next_doc.to_string()], Some(32))

.unwrap()
.into_iter()
.flatten()
Expand Down
Loading
Loading