Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tests #72

Merged
merged 2 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions examples/adapters/weaviate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
" return data\n",
"\n",
" def upsert(self, embeddings):\n",
" self.client.collections.get(self.index_name).data.insert_many(embeddings)\n",
" data = self.convert(embeddings)\n",
" self.client.collections.get(self.index_name).data.insert_many(data)\n",
"\n",
" def delete_index(self, index_name: str):\n",
" self.client.collections.delete(index_name)"
Expand Down Expand Up @@ -181,11 +182,14 @@
}
],
"source": [
"embed_config = embed_anything.EmbedConfig(\n",
" cloud=embed_anything.CloudConfig(provider=\"OpenAI\", chunk_size=256)\n",
"model = embed_anything.EmbeddingModel.from_pretrained_cloud(\n",
" embed_anything.WhichModel.OpenAI, model_id=\"text-embedding-3-small\"\n",
")\n",
"data = embed_anything.embed_file(\n",
" \"test.pdf\", embeder=\"OpenAI\", adapter=weaviate_adapter, config=embed_config\n",
"data = embed_anything.embed_directory(\n",
" \"test_files\",\n",
" embeder=model,\n",
" adapter=weaviate_adapter,\n",
" config=embed_anything.ImageEmbedConfig(buffer_size=100),\n",
")"
]
},
Expand Down
16 changes: 15 additions & 1 deletion python/python/embed_anything/_embed_anything.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def embed_directory(
def embed_image_directory(
file_path: str,
embeder: EmbeddingModel,
config: ImageEmbedConfig | None = None,
adapter: Adapter | None = None,
) -> list[EmbedData]:
"""
Expand All @@ -112,6 +113,7 @@ def embed_image_directory(
Args:
file_path: The path to the directory containing the images to embed.
embeder: The embedding model to use.
config: The configuration for the embedding model.
adapter: The adapter to use for storing the embeddings in a vector database.

Returns:
Expand Down Expand Up @@ -220,12 +222,24 @@ class TextEmbedConfig:
batch_size: The batch size for processing the embeddings. Default is 32. Based on the memory, you can increase or decrease the batch size.
"""

def __init__(self, chunk_size: int | None = None, batch_size: int | None = None):
def __init__(self, chunk_size: int | None = 256, batch_size: int | None = 32):
self.chunk_size = chunk_size
self.batch_size = batch_size
chunk_size: int | None
batch_size: int | None

class ImageEmbedConfig:
"""
Represents the configuration for the Image Embedding model.

Attributes:
buffer_size: The buffer size for the Image Embedding model. Default is 100.
"""

def __init__(self, buffer_size: int | None = None):
self.buffer_size = buffer_size
buffer_size: int | None

class EmbeddingModel:
"""
Represents an embedding model.
Expand Down
14 changes: 11 additions & 3 deletions python/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ pub struct TextEmbedConfig {
impl TextEmbedConfig {
#[new]
#[pyo3(signature = (chunk_size=None, batch_size=None, buffer_size=None))]
pub fn new(chunk_size: Option<usize>, batch_size: Option<usize>, buffer_size: Option<usize>) -> Self {
pub fn new(
chunk_size: Option<usize>,
batch_size: Option<usize>,
buffer_size: Option<usize>,
) -> Self {
Self {
inner: embed_anything::config::TextEmbedConfig::new(chunk_size, batch_size, buffer_size),
inner: embed_anything::config::TextEmbedConfig::new(
chunk_size,
batch_size,
buffer_size,
),
}
}

Expand Down Expand Up @@ -47,4 +55,4 @@ impl ImageEmbedConfig {
pub fn buffer_size(&self) -> Option<usize> {
self.inner.buffer_size
}
}
}
Loading
Loading