Embedding string fix [COG-1900] (#742)

<!-- .github/pull_request_template.md -->

## Description
Allow embedding of big strings to support full row embedding in SQL
databases

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Igor Ilic 2025-04-16 22:39:06 +02:00 committed by GitHub
parent acd7abbd29
commit a036787ad1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 28 additions and 13 deletions

View file

@@ -1,7 +1,8 @@
import asyncio
from cognee.shared.logging_utils import get_logger
import math
from typing import List, Optional
import numpy as np
import math
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
@@ -74,20 +75,34 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list):
if len(text) == 1:
parts = [text]
else:
parts = [text[0 : math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2) :]]
if isinstance(text, list) and len(text) > 1:
mid = math.ceil(len(text) / 2)
left, right = text[:mid], text[mid:]
left_vecs, right_vecs = await asyncio.gather(
self.embed_text(left),
self.embed_text(right),
)
return left_vecs + right_vecs
parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)
            # If the caller passed ONE oversize string, split it into two
            # overlapping parts so we can process it
if isinstance(text, list) and len(text) == 1:
logger.debug(f"Pooling embeddings of text string with size: {len(text[0])}")
s = text[0]
third = len(s) // 3
# We are using thirds to intentionally have overlap between split parts
# for better embedding calculation
left_part, right_part = s[: third * 2], s[third:]
all_embeddings = []
for embeddings_part in embeddings:
all_embeddings.extend(embeddings_part)
# Recursively embed the split parts in parallel
(left_vec,), (right_vec,) = await asyncio.gather(
self.embed_text([left_part]),
self.embed_text([right_part]),
)
return all_embeddings
# POOL the two embeddings into one
pooled = (np.array(left_vec) + np.array(right_vec)) / 2
return [pooled.tolist()]
logger.error("Context window exceeded for embedding text: %s", str(error))
raise error

View file

@@ -39,7 +39,7 @@ async def index_data_points(data_points: list[DataPoint]):
field_name = index_name_and_field[first_occurence + 1 :]
try:
        # In case the amount of indexable points is too large we need to send them in batches
batch_size = 1000
batch_size = 100
for i in range(0, len(indexable_points), batch_size):
batch = indexable_points[i : i + batch_size]
await vector_engine.index_data_points(index_name, field_name, batch)