refactor: Change limit=0 to limit=None in vector search. First draft of the change, still WIP.

This commit is contained in:
Andrej Milicevic 2025-09-19 12:31:25 +02:00
parent 9de69d2bab
commit e3cde238ff
12 changed files with 1849 additions and 20 deletions

View file

@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = None,
limit: Optional[int] = None,
with_vector: bool = False,
):
"""
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
"Use this option only when vector data is required."
)
# In the case of excessive limit, or zero / negative value, limit will be set to 10.
# In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
logger.warning(
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
"Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
"Defaulting to limit=10.",
limit,
)

View file

@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
):
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
try:
collection = await self.get_collection(collection_name)
if limit == 0:
if not limit:
limit = await collection.count()
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
results = await collection.query(
query_embeddings=[query_vector],
include=["metadatas", "distances", "embeddings"]
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
for row in vector_list
]
except Exception as e:
logger.error(f"Error in search: {str(e)}")
logger.warning(f"Error in search: {str(e)}")
return []
async def batch_search(

View file

@ -223,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
):
@ -235,11 +235,11 @@ class LanceDBAdapter(VectorDBInterface):
collection = await self.get_collection(collection_name)
if limit == 0:
if not limit:
limit = await collection.count_rows()
# LanceDB search will break if limit is 0 so we must return
if limit == 0:
if limit <= 0:
return []
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
@ -264,7 +264,7 @@ class LanceDBAdapter(VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)

View file

@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
limit: Optional[int] = 15,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
if not limit:
async with self.get_async_session() as session:
query = select(func.count()).select_from(PGVectorDataPoint)
result = await session.execute(query)
limit = result.scalar()
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []

View file

@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
collection_name: str,
query_text: Optional[str],
query_vector: Optional[List[float]],
limit: int,
limit: Optional[int],
with_vector: bool = False,
):
"""
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
collection.
- query_vector (Optional[List[float]]): An optional vector representation for
searching the collection.
- limit (int): The maximum number of results to return from the search.
- limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search
results. (default False)
"""
@ -106,7 +106,7 @@ class VectorDBInterface(Protocol):
@abstractmethod
async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
self, collection_name: str, query_texts: List[str], limit: Optional[int], with_vectors: bool = False
):
"""
Perform a batch search using multiple text queries against a collection.
@ -116,7 +116,7 @@ class VectorDBInterface(Protocol):
- collection_name (str): The name of the collection to conduct the batch search in.
- query_texts (List[str]): A list of text queries to use for the search.
- limit (int): The maximum number of results to return for each query.
- limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search
results. (default False)
"""

View file

@ -161,7 +161,7 @@ class CogneeGraph(CogneeAbstractGraph):
edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name",
query_vector=query_vector,
limit=0,
limit=None,
)
projection_time = time.time() - start_time
logger.info(

View file

@ -25,7 +25,7 @@ class InsightsRetriever(BaseGraphRetriever):
- top_k
"""
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
"""Initialize retriever with exploration levels and search parameters."""
self.exploration_levels = exploration_levels
self.top_k = top_k

View file

@ -129,7 +129,7 @@ class TemporalRetriever(GraphCompletionRetriever):
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
vector_search_results = await vector_engine.search(
collection_name="Event_name", query_vector=query_vector, limit=0
collection_name="Event_name", query_vector=query_vector, limit=None
)
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)

View file

@ -144,7 +144,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=0
collection_name=collection_name, query_vector=query_vector, limit=None
)
except CollectionNotFoundError:
return []

File diff suppressed because it is too large. Load diff

View file

@ -0,0 +1,40 @@
import os
import pathlib
import pytest
import cognee
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
class TestVectorEngine:
# Test that vector engine search works well with limit=None.
# Search should return all triplets that exist. Used Alice for a bit larger test.
@pytest.mark.asyncio
async def test_vector_engine_search_none_limit(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_vector_engine_search_none_limit"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_vector_engine_search_none_limit"
)
cognee.config.data_root_directory(data_directory_path)
file_path = os.path.join(pathlib.Path(__file__).resolve().parent, "data", "alice_in_wonderland.txt")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(file_path)
await cognee.cognify()
query_text = "List me all the important characters in Alice in Wonderland."
# Use high value to make sure we get everything that the vector search returns
retriever = GraphCompletionRetriever(top_k=1000)
result = await retriever.get_triplets(query_text)
# Check that we did not accidentally use any default value for limit in vector search along the way (like 5, 10, or 15)
assert len(result) > 15

View file

@ -15,6 +15,9 @@ async def cognee_demo():
current_directory = Path(__file__).resolve().parent.parent
file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Call Cognee to process document
await cognee.add(file_path)
await cognee.cognify()