refactor: Change limit=0 to limit=None in vector search. Initial commit, still work in progress.

This commit is contained in:
Andrej Milicevic 2025-09-19 12:31:25 +02:00
parent 9de69d2bab
commit e3cde238ff
12 changed files with 1849 additions and 20 deletions

View file

@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
collection_name: str, collection_name: str,
query_text: Optional[str] = None, query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
limit: int = None, limit: Optional[int] = None,
with_vector: bool = False, with_vector: bool = False,
): ):
""" """
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
"Use this option only when vector data is required." "Use this option only when vector data is required."
) )
# In the case of excessive limit, or zero / negative value, limit will be set to 10. # In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND: if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
logger.warning( logger.warning(
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). " "Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
"Defaulting to limit=10.", "Defaulting to limit=10.",
limit, limit,
) )

View file

@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 15, limit: Optional[int] = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ):
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
try: try:
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
if limit == 0: if not limit:
limit = await collection.count() limit = await collection.count()
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
results = await collection.query( results = await collection.query(
query_embeddings=[query_vector], query_embeddings=[query_vector],
include=["metadatas", "distances", "embeddings"] include=["metadatas", "distances", "embeddings"]
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
for row in vector_list for row in vector_list
] ]
except Exception as e: except Exception as e:
logger.error(f"Error in search: {str(e)}") logger.warning(f"Error in search: {str(e)}")
return [] return []
async def batch_search( async def batch_search(

View file

@ -223,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 15, limit: Optional[int] = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ):
@ -235,11 +235,11 @@ class LanceDBAdapter(VectorDBInterface):
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
if limit == 0: if not limit:
limit = await collection.count_rows() limit = await collection.count_rows()
# LanceDB search will break if limit is 0 so we must return # LanceDB search will break if limit is 0 so we must return
if limit == 0: if limit <= 0:
return [] return []
results = await collection.vector_search(query_vector).limit(limit).to_pandas() results = await collection.vector_search(query_vector).limit(limit).to_pandas()
@ -264,7 +264,7 @@ class LanceDBAdapter(VectorDBInterface):
self, self,
collection_name: str, collection_name: str,
query_texts: List[str], query_texts: List[str],
limit: int = None, limit: Optional[int] = None,
with_vectors: bool = False, with_vectors: bool = False,
): ):
query_vectors = await self.embedding_engine.embed_text(query_texts) query_vectors = await self.embedding_engine.embed_text(query_texts)

View file

@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
collection_name: str, collection_name: str,
query_text: Optional[str] = None, query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
limit: int = 15, limit: Optional[int] = 15,
with_vector: bool = False, with_vector: bool = False,
) -> List[ScoredResult]: ) -> List[ScoredResult]:
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name) PGVectorDataPoint = await self.get_table(collection_name)
if not limit:
async with self.get_async_session() as session:
query = select(func.count()).select_from(PGVectorDataPoint)
result = await session.execute(query)
limit = result.rowcount
# If limit is still 0, no need to do the search, just return empty results
if limit <= 0:
return []
# NOTE: This needs to be initialized in case search doesn't return a value # NOTE: This needs to be initialized in case search doesn't return a value
closest_items = [] closest_items = []

View file

@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
collection_name: str, collection_name: str,
query_text: Optional[str], query_text: Optional[str],
query_vector: Optional[List[float]], query_vector: Optional[List[float]],
limit: int, limit: Optional[int],
with_vector: bool = False, with_vector: bool = False,
): ):
""" """
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
collection. collection.
- query_vector (Optional[List[float]]): An optional vector representation for - query_vector (Optional[List[float]]): An optional vector representation for
searching the collection. searching the collection.
- limit (int): The maximum number of results to return from the search. - limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search - with_vector (bool): Whether to return the vector representations with search
results. (default False) results. (default False)
""" """
@ -106,7 +106,7 @@ class VectorDBInterface(Protocol):
@abstractmethod @abstractmethod
async def batch_search( async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False self, collection_name: str, query_texts: List[str], limit: Optional[int], with_vectors: bool = False
): ):
""" """
Perform a batch search using multiple text queries against a collection. Perform a batch search using multiple text queries against a collection.
@ -116,7 +116,7 @@ class VectorDBInterface(Protocol):
- collection_name (str): The name of the collection to conduct the batch search in. - collection_name (str): The name of the collection to conduct the batch search in.
- query_texts (List[str]): A list of text queries to use for the search. - query_texts (List[str]): A list of text queries to use for the search.
- limit (int): The maximum number of results to return for each query. - limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search - with_vectors (bool): Whether to include vector representations with search
results. (default False) results. (default False)
""" """

View file

@ -161,7 +161,7 @@ class CogneeGraph(CogneeAbstractGraph):
edge_distances = await vector_engine.search( edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name", collection_name="EdgeType_relationship_name",
query_vector=query_vector, query_vector=query_vector,
limit=0, limit=None,
) )
projection_time = time.time() - start_time projection_time = time.time() - start_time
logger.info( logger.info(

View file

@ -25,7 +25,7 @@ class InsightsRetriever(BaseGraphRetriever):
- top_k - top_k
""" """
def __init__(self, exploration_levels: int = 1, top_k: int = 5): def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
"""Initialize retriever with exploration levels and search parameters.""" """Initialize retriever with exploration levels and search parameters."""
self.exploration_levels = exploration_levels self.exploration_levels = exploration_levels
self.top_k = top_k self.top_k = top_k

View file

@ -129,7 +129,7 @@ class TemporalRetriever(GraphCompletionRetriever):
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0] query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
vector_search_results = await vector_engine.search( vector_search_results = await vector_engine.search(
collection_name="Event_name", query_vector=query_vector, limit=0 collection_name="Event_name", query_vector=query_vector, limit=None
) )
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)

View file

@ -144,7 +144,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str): async def search_in_collection(collection_name: str):
try: try:
return await vector_engine.search( return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=0 collection_name=collection_name, query_vector=query_vector, limit=None
) )
except CollectionNotFoundError: except CollectionNotFoundError:
return [] return []

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,40 @@
import os
import pathlib
import pytest
import cognee
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
class TestVectorEngine:
    """Integration test for vector engine search with limit=None.

    With limit=None the search is expected to return every matching triplet
    instead of falling back to a small default. Alice in Wonderland is used
    as a moderately large fixture so the result set clearly exceeds any
    default cap.
    """

    @pytest.mark.asyncio
    async def test_vector_engine_search_none_limit(self):
        base_dir = pathlib.Path(__file__).parent

        # Isolate system and data storage so this test cannot interfere
        # with other runs.
        cognee.config.system_root_directory(
            os.path.join(base_dir, ".cognee_system/test_vector_engine_search_none_limit")
        )
        cognee.config.data_root_directory(
            os.path.join(base_dir, ".data_storage/test_vector_engine_search_none_limit")
        )

        document_path = os.path.join(
            pathlib.Path(__file__).resolve().parent, "data", "alice_in_wonderland.txt"
        )

        # Start from a clean slate, then ingest and process the document.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await cognee.add(document_path)
        await cognee.cognify()

        question = "List me all the important characters in Alice in Wonderland."
        # top_k is deliberately huge so the retriever itself never caps the
        # number of triplets coming back from the vector search.
        retriever = GraphCompletionRetriever(top_k=1000)
        triplets = await retriever.get_triplets(question)

        # If any intermediate step silently applied a default limit
        # (like 5, 10, or 15), the result count could not exceed 15.
        assert len(triplets) > 15

View file

@ -15,6 +15,9 @@ async def cognee_demo():
current_directory = Path(__file__).resolve().parent.parent current_directory = Path(__file__).resolve().parent.parent
file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt") file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Call Cognee to process document # Call Cognee to process document
await cognee.add(file_path) await cognee.add(file_path)
await cognee.cognify() await cognee.cognify()