refactor: change limit=0 to limit=None in vector search (initial commit, still WIP)
This commit is contained in:
parent
9de69d2bab
commit
e3cde238ff
12 changed files with 1849 additions and 20 deletions
|
|
@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
"""
|
||||
|
|
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
"Use this option only when vector data is required."
|
||||
)
|
||||
|
||||
# In the case of excessive limit, or zero / negative value, limit will be set to 10.
|
||||
# In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
|
||||
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
||||
logger.warning(
|
||||
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
|
||||
"Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
|
||||
"Defaulting to limit=10.",
|
||||
limit,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
try:
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if limit == 0:
|
||||
if not limit:
|
||||
limit = await collection.count()
|
||||
|
||||
# If limit is still 0, no need to do the search, just return empty results
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
results = await collection.query(
|
||||
query_embeddings=[query_vector],
|
||||
include=["metadatas", "distances", "embeddings"]
|
||||
|
|
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
for row in vector_list
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search: {str(e)}")
|
||||
logger.warning(f"Error in search: {str(e)}")
|
||||
return []
|
||||
|
||||
async def batch_search(
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -235,11 +235,11 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if limit == 0:
|
||||
if not limit:
|
||||
limit = await collection.count_rows()
|
||||
|
||||
# LanceDB search will break if limit is 0 so we must return
|
||||
if limit == 0:
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||
|
|
@ -264,7 +264,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: int = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
|
|||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
|
@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
|
|
@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
if not limit:
|
||||
async with self.get_async_session() as session:
|
||||
query = select(func.count()).select_from(PGVectorDataPoint)
|
||||
result = await session.execute(query)
|
||||
limit = result.rowcount
|
||||
|
||||
# If limit is still 0, no need to do the search, just return empty results
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
|
|||
collection_name: str,
|
||||
query_text: Optional[str],
|
||||
query_vector: Optional[List[float]],
|
||||
limit: int,
|
||||
limit: Optional[int],
|
||||
with_vector: bool = False,
|
||||
):
|
||||
"""
|
||||
|
|
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
|
|||
collection.
|
||||
- query_vector (Optional[List[float]]): An optional vector representation for
|
||||
searching the collection.
|
||||
- limit (int): The maximum number of results to return from the search.
|
||||
- limit (Optional[int]): The maximum number of results to return from the search.
|
||||
- with_vector (bool): Whether to return the vector representations with search
|
||||
results. (default False)
|
||||
"""
|
||||
|
|
@ -106,7 +106,7 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
@abstractmethod
|
||||
async def batch_search(
|
||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
||||
self, collection_name: str, query_texts: List[str], limit: Optional[int], with_vectors: bool = False
|
||||
):
|
||||
"""
|
||||
Perform a batch search using multiple text queries against a collection.
|
||||
|
|
@ -116,7 +116,7 @@ class VectorDBInterface(Protocol):
|
|||
|
||||
- collection_name (str): The name of the collection to conduct the batch search in.
|
||||
- query_texts (List[str]): A list of text queries to use for the search.
|
||||
- limit (int): The maximum number of results to return for each query.
|
||||
- limit (Optional[int]): The maximum number of results to return for each query.
|
||||
- with_vectors (bool): Whether to include vector representations with search
|
||||
results. (default False)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_vector=query_vector,
|
||||
limit=0,
|
||||
limit=None,
|
||||
)
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class InsightsRetriever(BaseGraphRetriever):
|
|||
- top_k
|
||||
"""
|
||||
|
||||
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
|
||||
def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
|
||||
"""Initialize retriever with exploration levels and search parameters."""
|
||||
self.exploration_levels = exploration_levels
|
||||
self.top_k = top_k
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||
|
||||
vector_search_results = await vector_engine.search(
|
||||
collection_name="Event_name", query_vector=query_vector, limit=0
|
||||
collection_name="Event_name", query_vector=query_vector, limit=None
|
||||
)
|
||||
|
||||
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ async def brute_force_triplet_search(
|
|||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=0
|
||||
collection_name=collection_name, query_vector=query_vector, limit=None
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
import pathlib
|
||||
import pytest
|
||||
|
||||
import cognee
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
||||
|
||||
class TestVectorEngine:
    """Tests for vector engine search behavior when ``limit=None`` is passed."""

    @pytest.mark.asyncio
    async def test_vector_engine_search_none_limit(self):
        """``limit=None`` must return every matching triplet, not a default-capped subset.

        Uses the Alice in Wonderland corpus so the result set is large enough
        to distinguish "all results" from any adapter's default limit.
        """
        base_dir = pathlib.Path(__file__).parent

        # Isolate this test's system and data storage from other test runs.
        cognee.config.system_root_directory(
            os.path.join(base_dir, ".cognee_system/test_vector_engine_search_none_limit")
        )
        cognee.config.data_root_directory(
            os.path.join(base_dir, ".data_storage/test_vector_engine_search_none_limit")
        )

        corpus_path = os.path.join(
            pathlib.Path(__file__).resolve().parent, "data", "alice_in_wonderland.txt"
        )

        # Start from a clean slate, then ingest and process the corpus.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await cognee.add(corpus_path)
        await cognee.cognify()

        # top_k is deliberately huge so the retriever itself never truncates
        # whatever the underlying vector search returned.
        retriever = GraphCompletionRetriever(top_k=1000)
        triplets = await retriever.get_triplets(
            "List me all the important characters in Alice in Wonderland."
        )

        # If any intermediate vector search silently fell back to a default
        # limit along the way (like 5, 10, or 15), we would see at most that
        # many results here.
        assert len(triplets) > 15
|
||||
|
|
@ -15,6 +15,9 @@ async def cognee_demo():
|
|||
current_directory = Path(__file__).resolve().parent.parent
|
||||
file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt")
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
# Call Cognee to process document
|
||||
await cognee.add(file_path)
|
||||
await cognee.cognify()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue