Embedding string fix [COG-1900] (#742)
<!-- .github/pull_request_template.md --> ## Description Allow embedding of big strings to support full row embedding in SQL databases ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
acd7abbd29
commit
a036787ad1
2 changed files with 28 additions and 13 deletions
|
|
@ -1,7 +1,8 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
import math
|
||||
import litellm
|
||||
import os
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
|
@ -74,20 +75,34 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
return [data["embedding"] for data in response.data]
|
||||
|
||||
except litellm.exceptions.ContextWindowExceededError as error:
|
||||
if isinstance(text, list):
|
||||
if len(text) == 1:
|
||||
parts = [text]
|
||||
else:
|
||||
parts = [text[0 : math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2) :]]
|
||||
if isinstance(text, list) and len(text) > 1:
|
||||
mid = math.ceil(len(text) / 2)
|
||||
left, right = text[:mid], text[mid:]
|
||||
left_vecs, right_vecs = await asyncio.gather(
|
||||
self.embed_text(left),
|
||||
self.embed_text(right),
|
||||
)
|
||||
return left_vecs + right_vecs
|
||||
|
||||
parts_futures = [self.embed_text(part) for part in parts]
|
||||
embeddings = await asyncio.gather(*parts_futures)
|
||||
# If caller passed ONE oversized string, split the string itself in
|
||||
# half so we can process it
|
||||
if isinstance(text, list) and len(text) == 1:
|
||||
logger.debug(f"Pooling embeddings of text string with size: {len(text[0])}")
|
||||
s = text[0]
|
||||
third = len(s) // 3
|
||||
# We are using thirds to intentionally have overlap between split parts
|
||||
# for better embedding calculation
|
||||
left_part, right_part = s[: third * 2], s[third:]
|
||||
|
||||
all_embeddings = []
|
||||
for embeddings_part in embeddings:
|
||||
all_embeddings.extend(embeddings_part)
|
||||
# Recursively embed the split parts in parallel
|
||||
(left_vec,), (right_vec,) = await asyncio.gather(
|
||||
self.embed_text([left_part]),
|
||||
self.embed_text([right_part]),
|
||||
)
|
||||
|
||||
return all_embeddings
|
||||
# POOL the two embeddings into one
|
||||
pooled = (np.array(left_vec) + np.array(right_vec)) / 2
|
||||
return [pooled.tolist()]
|
||||
|
||||
logger.error("Context window exceeded for embedding text: %s", str(error))
|
||||
raise error
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ async def index_data_points(data_points: list[DataPoint]):
|
|||
field_name = index_name_and_field[first_occurence + 1 :]
|
||||
try:
|
||||
# In case the amount of indexable points is too large we need to send them in batches
|
||||
batch_size = 1000
|
||||
batch_size = 100
|
||||
for i in range(0, len(indexable_points), batch_size):
|
||||
batch = indexable_points[i : i + batch_size]
|
||||
await vector_engine.index_data_points(index_name, field_name, batch)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue