
Commit 7cc779a
Merge pull request StarlightSearch#65 from akshayballal95/main
add adapter and metadata to image embedding

Former-commit-id: 06e7037
akshayballal95 authored Aug 29, 2024
2 parents e9c85d0 + 0f71a09 commit 7cc779a
Showing 3 changed files with 51 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/adapters/weaviate.ipynb
@@ -260,7 +260,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.11.9"
   }
  },
  "nbformat": 4,
7 changes: 6 additions & 1 deletion src/embeddings/local/clip.rs
@@ -219,10 +219,15 @@ impl EmbedImage for ClipEmbeder {
             .iter()
             .zip(image_paths)
             .map(|(data, path)| {
+                let mut metadata = HashMap::new();
+                metadata.insert(
+                    "file_name".to_string(),
+                    path.as_ref().to_str().unwrap().to_string(),
+                );
                 EmbedData::new(
                     data.to_vec(),
-                    Some(path.as_ref().to_str().unwrap().to_string()),
                     None,
+                    Some(metadata),
                 )
             })
             .collect::<Vec<_>>();
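
The practical effect of this hunk: the image path no longer rides in the EmbedData text field; it travels in the metadata map under the "file_name" key. A minimal sketch of what a caller might observe from the Python side (the embed_file call shape and the text/metadata attributes on EmbedData are illustrative assumptions, not shown in this diff):

import embed_anything

# Hypothetical call shape: embed a single image with the CLIP model.
data = embed_anything.embed_file("photo.png", embeder="Clip")

for emb in data:
    # After this change, the path is expected in metadata["file_name"]
    # rather than in the text field.
    print(emb.metadata["file_name"])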
58 changes: 44 additions & 14 deletions src/lib.rs
@@ -11,7 +11,7 @@ pub mod file_loader;
 pub mod file_processor;
 pub mod text_loader;
 
-use std::path::PathBuf;
+use std::{collections::HashMap, path::PathBuf};
 
 use config::{AudioDecoderConfig, BertConfig, ClipConfig, CloudConfig, EmbedConfig, JinaConfig};
 use embeddings::{
@@ -145,7 +145,7 @@ fn embed_default(file_name: &str, embeder: &str) -> PyResult<Option<Vec<EmbedData>>> {
         "Cloud"|"OpenAI" => emb_text(file_name, &Embeder::Cloud(CloudEmbeder::OpenAI(OpenAIEmbeder::default())), None, None, None),
         "Jina" => emb_text(file_name, &Embeder::Jina(JinaEmbeder::default()), None, None, None),
         "Bert" => emb_text(file_name, &Embeder::Bert(BertEmbeder::default()), None, None, None),
-        "Clip" => Ok(Some(vec![emb_image(file_name, ClipEmbeder::default())?])),
+        "Clip" => emb_image(file_name, ClipEmbeder::default(), None),
         "Audio" => emb_audio(file_name, &config),
         _ => Err(PyValueError::new_err(
             "Invalid embedding model. Choose between OpenAI, Bert, Jina for text files and Clip for image files.",
@@ -282,7 +282,7 @@ pub fn embed_file(
         )?
     } else if let Some(clip_config) = &config.clip {
         let embeder = get_clip_embeder(clip_config)?;
-        Some(vec![emb_image(file_name, embeder)?])
+        emb_image(file_name, embeder, adapter)?
     } else if let Some(openai_config) = &config.cloud {
         let embeder = get_cloud_embeder(openai_config)?;
         let chunk_size = openai_config.chunk_size.unwrap_or(256);
@@ -369,7 +369,7 @@ pub fn embed_directory(
         )?)
     } else if let Some(clip_config) = &config.clip {
         let embeder = get_clip_embeder(clip_config)?;
-        Ok(emb_image_directory(directory, embeder)?)
+        Ok(emb_image_directory(directory, embeder, adapter)?)
     } else if let Some(_openai_config) = &config.cloud {
         let embeder = get_cloud_embeder(_openai_config)?;
         let chunk_size = _openai_config.chunk_size.unwrap_or(256);
@@ -409,16 +409,14 @@ pub fn embed_directory(
         ),
         "Jina" => Ok(emb_directory(directory, &Embeder::Jina(JinaEmbeder::default()), extensions, None, None, adapter)?),
         "Bert" => Ok(emb_directory(directory, &Embeder::Bert(BertEmbeder::default()), extensions, None, None, adapter)?),
-        "Clip" => Ok(emb_image_directory(directory, ClipEmbeder::default())?),
+        "Clip" => Ok(emb_image_directory(directory, ClipEmbeder::default(), None)?),
         _ => {
             Err(PyValueError::new_err(
                 "Invalid embedding model. Choose between OpenAI and Bert for text files and Clip for image files.",
             ))
         }
     }
-
-    // Send embeddings to vector database
 }
 
 /// Embeddings of a webpage using the specified embedding model.
@@ -641,9 +639,27 @@ fn emb_text<T: AsRef<std::path::Path>>(
 fn emb_image<T: AsRef<std::path::Path>, U: EmbedImage>(
     image_path: T,
     embedding_model: U,
-) -> PyResult<EmbedData> {
-    let embedding = embedding_model.embed_image(image_path, None).unwrap();
-    Ok(embedding)
+    adapter: Option<PyObject>,
+) -> PyResult<Option<Vec<EmbedData>>> {
+    let mut metadata = HashMap::new();
+    metadata.insert(
+        "file_name".to_string(),
+        image_path.as_ref().to_str().unwrap().to_string(),
+    );
+    if let Some(adapter) = adapter {
+        Python::with_gil(|py| {
+            let embedding = vec![embedding_model.embed_image(image_path, Some(metadata)).unwrap()];
+            let conversion_fn = adapter.getattr(py, "convert")?;
+            let upsert_fn = adapter.getattr(py, "upsert")?;
+            let converted_embedding = conversion_fn.call1(py, (embedding,))?;
+            upsert_fn.call1(py, (&converted_embedding,))?;
+
+            Ok(None)
+        })
+    } else {
+        let embedding = embedding_model.embed_image(image_path, None).unwrap();
+        Ok(Some(vec![embedding]))
+    }
 }
 
 pub fn emb_audio<T: AsRef<std::path::Path>>(
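
On the Rust side the adapter is duck-typed: the only requirements visible in this diff are a convert method, which receives the list of EmbedData, and an upsert method, which receives whatever convert returned. A minimal sketch of a Python adapter satisfying that protocol (the client, record shape, and EmbedData attribute names are illustrative assumptions, not taken from this diff):

class VectorStoreAdapter:
    """Illustrative adapter: the Rust code only calls convert(...)
    followed by upsert(...) on this object."""

    def __init__(self, client, collection):
        self.client = client          # hypothetical vector-db client
        self.collection = collection  # hypothetical collection name

    def convert(self, embeddings):
        # Reshape EmbedData objects into plain records for the store;
        # the embedding/metadata attribute names are assumptions.
        return [
            {"vector": e.embedding, "metadata": e.metadata}
            for e in embeddings
        ]

    def upsert(self, records):
        # Persist the converted records; insert() is a stand-in for
        # whatever the real client exposes.
        for record in records:
            self.client.insert(self.collection, record)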
@@ -693,12 +709,26 @@ pub fn emb_audio<T: AsRef<std::path::Path>>(
 fn emb_image_directory<T: EmbedImage>(
     directory: PathBuf,
     embedding_model: T,
+    adapter: Option<PyObject>,
 ) -> PyResult<Option<Vec<EmbedData>>> {
     let mut file_parser = FileParser::new();
     file_parser.get_image_paths(&directory).unwrap();
 
-    let embeddings = embedding_model
-        .embed_image_batch(&file_parser.files)
-        .unwrap();
-    Ok(Some(embeddings))
+    if let Some(adapter) = adapter {
+        Python::with_gil(|py| {
+            let embeddings = embedding_model
+                .embed_image_batch(&file_parser.files)
+                .unwrap();
+            let conversion_fn = adapter.getattr(py, "convert")?;
+            let upsert_fn = adapter.getattr(py, "upsert")?;
+            let converted_embeddings = conversion_fn.call1(py, (embeddings,))?;
+            upsert_fn.call1(py, (&converted_embeddings,))?;
+
+            // The adapter has already persisted the embeddings, so return None.
+            Ok(None)
+        })
+    } else {
+        let embeddings = embedding_model.embed_image_batch(&file_parser.files).unwrap();
+        Ok(Some(embeddings))
+    }
 }
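
Note the return-value contract both functions now share: with an adapter, the embeddings are converted, upserted, and None is returned to Python; without one, the embeddings come back to the caller. A hedged usage sketch (the embed_directory keyword arguments are assumptions based on the signatures above, and the client is a placeholder):

import embed_anything

client = ...  # hypothetical vector-db client
adapter = VectorStoreAdapter(client, "Images")  # sketch from the previous block

# With an adapter, embeddings are upserted via convert/upsert and the
# call is expected to return None.
result = embed_anything.embed_directory("images/", embeder="Clip", adapter=adapter)
assert result is None

# Without an adapter, the embeddings are returned directly.
embeddings = embed_anything.embed_directory("images/", embeder="Clip")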
