Skip to content

Commit

Permalink
Merge pull request StarlightSearch#75 from akshayballal95/main
Browse files Browse the repository at this point in the history
Remove TextEmbed

Former-commit-id: 446d095
  • Loading branch information
akshayballal95 authored Sep 11, 2024
2 parents 7c40997 + 253ea00 commit 26ea4b7
Show file tree
Hide file tree
Showing 15 changed files with 199 additions and 91 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions docs/blog/.authors.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
authors:
akshay:
name: Akshay Ballal
description: Creator of EmbedAnything
avatar: https://pbs.twimg.com/profile_images/1660187462357127168/6dV9SpLi_400x400.jpg
sonam:
name: Sonam Pankaj
description: Creator of EmbedAnything
avatar: https://pbs.twimg.com/profile_images/1798985783292125184/L6YQmg1Q_400x400.jpg
1 change: 1 addition & 0 deletions docs/blog/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# All Posts
150 changes: 150 additions & 0 deletions docs/blog/posts/vector-streaming.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
---
draft: false
date: 2024-01-31
authors:
- akshay
- sonam
slug: vector-streaming
---
Introducing vector streaming in EmbedAnything, a feature designed to optimize large-scale document embedding. By enabling asynchronous chunking and embedding using Rust’s concurrency, it reduces memory usage and speeds up the process. We also show how to integrate it with the Weaviate Vector Database for seamless image embedding and search.


<!-- more -->

In my previous article [Supercharge Your Embeddings Pipeline with EmbedAnything](https://www.analyticsvidhya.com/blog/2024/06/supercharge-your-embeddings-pipeline-with-embedanything/), I discussed the idea behind EmbedAnything and how it makes creating embeddings from multiple modalities easy. In this article, I want to introduce a new feature of EmbedAnything called vector streaming and see how it works with Weaviate Vector Database.

### What is the problem?

First, let's examine the current problem with creating embeddings, especially in large-scale documents. The current embedding frameworks operate on a two-step process: chunking and embedding. First, the text is extracted from all the files, and chunks/nodes are created. Then, these chunks are fed to an embedding model with a specific batch size to process the embeddings. While this is done, the chunks and the embeddings stay on the system memory. This is not a problem when the files are small, and the embedding dimensions are small. But this becomes a problem when there are many files and you are working with large models and, even worse, multi-vector embeddings. Thus, to work with this, a high RAM is required to process the embeddings. Also, if this is done synchronously, a lot of time is wasted while the chunks are being created, as chunking is not a compute-heavy operation. As the chunks are being made, passing them to the embedding model would be efficient.

### Our Solution

The solution is to create an asynchronous chunking and embedding task. We can effectively spawn threads to handle this task using Rust's concurrency patterns and thread safety. This is done using Rust's MPSC (Multi-producer Single Consumer) module, which passes messages between threads. Thus, this creates a stream of chunks passed into the embedding thread with a buffer. Once the buffer is complete, it embeds the chunks and sends the embeddings back to the main thread, where they are sent to the vector database. This ensures no time is wasted on a single operation and no bottlenecks. Moreover, only the chunks and embeddings in the buffer are stored in the system memory. They are erased from the memory once moved to the vector database.


![Vector Streaming](https://res.cloudinary.com/dltwftrgc/image/upload/v1726073108/vector_streaming_m6xa1j.png)



### Example Use Case

Now, let's see this feature in action.

With EmbedAnything, streaming the vectors from a directory of files to the vector database is a simple three-step process.

1. **Create an adapter for your vector database:** This is a wrapper around the database's functions that allows you to create an index, convert metadata from EmbedAnything's format to the format required by the database, and the function to insert the embeddings in the index. Adapters for the prominent databases are already created and present [here](https://github.com/StarlightSearch/EmbedAnything/tree/main/examples/adapters):

2. **Initiate an embedding model of your choice:** You can choose from different local models or even cloud models. The configuration can also be determined to set the chunk size and buffer size for how many embeddings need to be streamed at once. Ideally, this should be as high as possible, but the system RAM limits this.

3. **Call the embedding function from EmbedAnything:** Just pass the directory path to be embedded, the embedding model, the adapter, and the configuration.

In this example, we will embed a directory of images and send it to the vector databases.

#### Step 1: Create the Adapter

In EmbedAnything, the adapters are created outside so as to not make the library heavy and you get to choose which database you want to work with. Here is a simple adapter for Weaviate.

```python
from embed_anything import EmbedData
from embed_anything.vectordb import Adapter

class WeaviateAdapter(Adapter):
def __init__(self, api_key, url):
super().__init__(api_key)
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url=url, auth_credentials=wvc.init.Auth.api_key(api_key)
)
if self.client.is_ready():
print("Weaviate is ready")

def create_index(self, index_name: str):
self.index_name = index_name
self.collection = self.client.collections.create(
index_name, vectorizer_config=wvc.config.Configure.Vectorizer.none()
)
return self.collection

def convert(self, embeddings: List[EmbedData]):
data = []
for embedding in embeddings:
property = embedding.metadata
property["text"] = embedding.text
data.append(
wvc.data.DataObject(properties=property, vector=embedding.embedding)
)
return data

def upsert(self, embeddings):
data = self.convert(embeddings)
self.client.collections.get(self.index_name).data.insert_many(data)

def delete_index(self, index_name: str):
self.client.collections.delete(index_name)

### Start the client and index

URL = "your-weaviate-url"
API_KEY = "your-weaviate-api-key"
weaviate_adapter = WeaviateAdapter(API_KEY, URL)

index_name = "Test_index"
if index_name in weaviate_adapter.client.collections.list_all():
weaviate_adapter.delete_index(index_name)
weaviate_adapter.create_index("Test_index")
```


#### Step 2: Create the Embedding Model

Here, since we are embedding images, we can use the clip model

```python
import embed_anything import WhichModel

model = embed_anything.EmbeddingModel.from_pretrained_cloud(
embed_anything.WhichModel.Clip,
model_id="openai/clip-vit-base-patch16")

```

#### Step 3: Embed the Directory

```python

data = embed_anything.embed_image_directory(
"\image_directory",
embeder=model,
adapter=weaviate_adapter,
config=embed_anything.ImageEmbedConfig(buffer_size=100),
)

```

#### Step 4: Query the Vector Database

```python
query_vector = embed_anything.embed_query(["image of a cat"], embeder=model)[0].embedding
```

#### Step 5: Query the Vector Database

```python
response = weaviate_adapter.collection.query.near_vector(
near_vector=query_vector,
limit=2,
return_metadata=wvc.query.MetadataQuery(certainty=True),
)
```

Check the response;


![Output](https://res.cloudinary.com/dltwftrgc/image/upload/v1726073341/Blogs/Vector%20Streaming/output_2_zsjg87.png)

Check out the notebook for the code here on colab

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17vUZEh-ZSpN339pIXSkyxtDHS5Sz6DqD?usp=sharing)

### Conclusion

We think vector streaming is one of the features that will empower many engineers to opt for a more optimized and no-tech debt solution. Instead of using bulky frameworks on the cloud, you can use a lightweight streaming option. Please don't forget to give us a ⭐ on our GitHub repo over here: [EmbedAnything Repo](https://github.com/StarlightSearch/EmbedAnything)
5 changes: 5 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ theme:
name: "material"
plugins:
- mkdocstrings
- blog:
archive: false

nav:
- index.md
- references.md
- Blog:
- blog/index.md


markdown_extensions:
- pymdownx.highlight:
Expand Down
6 changes: 4 additions & 2 deletions rust/examples/chunkers.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use embed_anything::text_loader::TextLoader;
use embed_anything::chunkers::statistical::StatisticalChunker;
fn main() {

#[tokio::main]
async fn main() {
let text = TextLoader::extract_text("/home/akshay/EmbedAnything/test_files/attention.pdf").unwrap();
let chunker = StatisticalChunker{
verbose: true,
..Default::default()
};
let chunks = chunker._chunk(&text, 32);
let chunks = chunker._chunk(&text, 32).await;

}
19 changes: 9 additions & 10 deletions rust/src/chunkers/cumulative.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
use crate::embeddings::{embed::TextEmbed, local::jina::JinaEmbeder};
use crate::embeddings::{embed::Embeder, local::jina::JinaEmbeder};
use candle_core::Tensor;
use text_splitter::{ChunkConfig, ChunkSizer, TextSplitter};
use tokenizers::Tokenizer;

#[derive(Debug)]
pub struct CumulativeChunker<T: TextEmbed, Sizer: ChunkSizer> {
pub encoder: T,
pub struct CumulativeChunker<Sizer: ChunkSizer> {
pub encoder: Embeder,
pub splitter: TextSplitter<Sizer>,
pub score_threshold: f32,
pub device: candle_core::Device,
}

impl Default for CumulativeChunker<JinaEmbeder, Tokenizer> {
impl Default for CumulativeChunker<Tokenizer> {
fn default() -> Self {
let splitter = TextSplitter::new(ChunkConfig::new(200).with_sizer(
Tokenizer::from_pretrained("BEE-spoke-data/cl100k_base-mlm", None).unwrap(),
));
let encoder = JinaEmbeder::default();
let encoder = Embeder::Jina(JinaEmbeder::default());
let score_threshold = 0.9;
let device = candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu);
Self {
Expand All @@ -28,8 +27,8 @@ impl Default for CumulativeChunker<JinaEmbeder, Tokenizer> {
}
}

impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
pub fn new(encoder: T, splitter: TextSplitter<Sizer>, score_threshold: f32) -> Self {
impl<Sizer: ChunkSizer> CumulativeChunker<Sizer> {
pub fn new(encoder: Embeder, splitter: TextSplitter<Sizer>, score_threshold: f32) -> Self {
Self {
encoder,
splitter,
Expand Down Expand Up @@ -70,15 +69,15 @@ impl<T: TextEmbed, Sizer: ChunkSizer> CumulativeChunker<T, Sizer> {
let curr_chunk_docs_embed = self
.encoder
.embed(&[curr_chunk_docs.to_string()], Some(32))

.await
.unwrap()
.into_iter()
.flatten()
.collect::<Vec<_>>();
let next_doc_embed = self
.encoder
.embed(&[next_doc.to_string()], Some(32))

.await
.unwrap()
.into_iter()
.flatten()
Expand Down
25 changes: 12 additions & 13 deletions rust/src/chunkers/statistical.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::cmp::max;

use crate::embeddings::{embed::TextEmbed, local::jina::JinaEmbeder};
use crate::embeddings::{embed::Embeder, local::jina::JinaEmbeder};
use candle_core::Tensor;
use itertools::{enumerate, Itertools};
use tokenizers::Tokenizer;

#[derive(Debug)]
pub struct StatisticalChunker<T: TextEmbed> {
pub encoder: T,
pub struct StatisticalChunker {
pub encoder: Embeder,
pub device: candle_core::Device,
pub threshold_adjustment: f32,
pub dynamic_threshold: bool,
Expand All @@ -18,10 +17,10 @@ pub struct StatisticalChunker<T: TextEmbed> {
pub tokenizer: Tokenizer,
pub verbose: bool,
}
impl Default for StatisticalChunker<JinaEmbeder> {
impl Default for StatisticalChunker {
fn default() -> Self {
let tokenizer = Tokenizer::from_pretrained("BEE-spoke-data/cl100k_base-mlm", None).unwrap();
let encoder = JinaEmbeder::default();
let encoder = Embeder::Jina(JinaEmbeder::default());
let device = candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu);
Self {
encoder,
Expand All @@ -38,10 +37,10 @@ impl Default for StatisticalChunker<JinaEmbeder> {
}
}

impl<T: TextEmbed> StatisticalChunker<T> {
impl StatisticalChunker {
#[allow(clippy::too_many_arguments)]
pub fn new(
encoder: T,
encoder: Embeder,
threshold_adjustment: f32,
dynamic_threshold: bool,
window_size: usize,
Expand Down Expand Up @@ -101,7 +100,7 @@ impl<T: TextEmbed> StatisticalChunker<T> {
Some(chunks)
}

pub fn _chunk(&self, text: &str, batch_size: usize) -> Vec<String> {
pub async fn _chunk(&self, text: &str, batch_size: usize) -> Vec<String> {
let splits = self.split_into_sentences(text, 50).unwrap();

if self.verbose {
Expand All @@ -126,7 +125,7 @@ impl<T: TextEmbed> StatisticalChunker<T> {
.collect::<Vec<_>>();
}

let encoded_splits = self.encoder.embed(&batch_splits, Some(16)).unwrap();
let encoded_splits = self.encoder.embed(&batch_splits, Some(16)).await.unwrap();
let similarities = self._calculate_similarity_scores(&encoded_splits);
let calculated_threshold = self._find_optimal_threshold(&batch_splits, &similarities);

Expand Down Expand Up @@ -331,15 +330,15 @@ mod tests {

use super::*;

#[test]
fn test_statistical_chunker() {
#[tokio::test]
async fn test_statistical_chunker() {
let text = TextLoader::extract_text("/home/akshay/EmbedAnything/test_files/attention.pdf").unwrap();
let chunker = StatisticalChunker{
verbose: true,
..Default::default()
};
println!("-----Text---\n{}", text);
let chunks = chunker._chunk(&text, 10);
let chunks = chunker._chunk(&text, 10).await;
assert_eq!(chunks.len(), 1);
}
}
11 changes: 0 additions & 11 deletions rust/src/embeddings/cloud/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use reqwest::Client;
use serde::Deserialize;
use serde_json::json;

use crate::embeddings::embed::TextEmbed;

/// Represents the response from the Cohere embedding API.
#[derive(Deserialize, Debug, Default)]
pub struct CohereEmbedResponse {
Expand Down Expand Up @@ -31,15 +29,6 @@ impl Default for CohereEmbeder {
}
}

impl TextEmbed for CohereEmbeder {
fn embed(
&self,
text_batch: &[String],
_batch_size: Option<usize>,
) -> Result<Vec<Vec<f32>>, anyhow::Error> {
tokio::runtime::Runtime::new()?.block_on(self.embed(text_batch))
}
}

impl CohereEmbeder {
/// Creates a new instance of `CohereEmbeder` with the specified model and API key.
Expand Down
Loading

0 comments on commit 26ea4b7

Please sign in to comment.