refactor: Change limit=0 to limit=None in vector search. Initial commit, still wip.
This commit is contained in:
parent
9de69d2bab
commit
e3cde238ff
12 changed files with 1849 additions and 20 deletions
|
|
@ -234,7 +234,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: Optional[List[float]] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = None,
|
limit: Optional[int] = None,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -265,10 +265,10 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
"Use this option only when vector data is required."
|
"Use this option only when vector data is required."
|
||||||
)
|
)
|
||||||
|
|
||||||
# In the case of excessive limit, or zero / negative value, limit will be set to 10.
|
# In the case of excessive limit, or None / zero / negative value, limit will be set to 10.
|
||||||
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
if not limit or limit <= self._TOPK_LOWER_BOUND or limit > self._TOPK_UPPER_BOUND:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Provided limit (%s) is invalid (zero, negative, or exceeds maximum). "
|
"Provided limit (%s) is invalid (None, zero, negative, or exceeds maximum). "
|
||||||
"Defaulting to limit=10.",
|
"Defaulting to limit=10.",
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -352,7 +352,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: str = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: List[float] = None,
|
||||||
limit: int = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
):
|
):
|
||||||
|
|
@ -386,9 +386,13 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
try:
|
try:
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
if limit == 0:
|
if not limit:
|
||||||
limit = await collection.count()
|
limit = await collection.count()
|
||||||
|
|
||||||
|
# If limit is still 0, no need to do the search, just return empty results
|
||||||
|
if limit <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
results = await collection.query(
|
results = await collection.query(
|
||||||
query_embeddings=[query_vector],
|
query_embeddings=[query_vector],
|
||||||
include=["metadatas", "distances", "embeddings"]
|
include=["metadatas", "distances", "embeddings"]
|
||||||
|
|
@ -428,7 +432,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in search: {str(e)}")
|
logger.warning(f"Error in search: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: str = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: List[float] = None,
|
||||||
limit: int = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
):
|
):
|
||||||
|
|
@ -235,11 +235,11 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
if limit == 0:
|
if not limit:
|
||||||
limit = await collection.count_rows()
|
limit = await collection.count_rows()
|
||||||
|
|
||||||
# LanceDB search will break if limit is 0 so we must return
|
# LanceDB search will break if limit is 0 so we must return
|
||||||
if limit == 0:
|
if limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||||
|
|
@ -264,7 +264,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = None,
|
limit: Optional[int] = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy.exc import ProgrammingError
|
from sqlalchemy.exc import ProgrammingError
|
||||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||||
|
|
@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: Optional[List[float]] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
) -> List[ScoredResult]:
|
) -> List[ScoredResult]:
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
|
|
@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
||||||
|
if not limit:
|
||||||
|
async with self.get_async_session() as session:
|
||||||
|
query = select(func.count()).select_from(PGVectorDataPoint)
|
||||||
|
result = await session.execute(query)
|
||||||
|
limit = result.rowcount
|
||||||
|
|
||||||
|
# If limit is still 0, no need to do the search, just return empty results
|
||||||
|
if limit <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||||
closest_items = []
|
closest_items = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: Optional[str],
|
query_text: Optional[str],
|
||||||
query_vector: Optional[List[float]],
|
query_vector: Optional[List[float]],
|
||||||
limit: int,
|
limit: Optional[int],
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
|
||||||
collection.
|
collection.
|
||||||
- query_vector (Optional[List[float]]): An optional vector representation for
|
- query_vector (Optional[List[float]]): An optional vector representation for
|
||||||
searching the collection.
|
searching the collection.
|
||||||
- limit (int): The maximum number of results to return from the search.
|
- limit (Optional[int]): The maximum number of results to return from the search.
|
||||||
- with_vector (bool): Whether to return the vector representations with search
|
- with_vector (bool): Whether to return the vector representations with search
|
||||||
results. (default False)
|
results. (default False)
|
||||||
"""
|
"""
|
||||||
|
|
@ -106,7 +106,7 @@ class VectorDBInterface(Protocol):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
self, collection_name: str, query_texts: List[str], limit: Optional[int], with_vectors: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch search using multiple text queries against a collection.
|
Perform a batch search using multiple text queries against a collection.
|
||||||
|
|
@ -116,7 +116,7 @@ class VectorDBInterface(Protocol):
|
||||||
|
|
||||||
- collection_name (str): The name of the collection to conduct the batch search in.
|
- collection_name (str): The name of the collection to conduct the batch search in.
|
||||||
- query_texts (List[str]): A list of text queries to use for the search.
|
- query_texts (List[str]): A list of text queries to use for the search.
|
||||||
- limit (int): The maximum number of results to return for each query.
|
- limit (Optional[int]): The maximum number of results to return for each query.
|
||||||
- with_vectors (bool): Whether to include vector representations with search
|
- with_vectors (bool): Whether to include vector representations with search
|
||||||
results. (default False)
|
results. (default False)
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
edge_distances = await vector_engine.search(
|
edge_distances = await vector_engine.search(
|
||||||
collection_name="EdgeType_relationship_name",
|
collection_name="EdgeType_relationship_name",
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=0,
|
limit=None,
|
||||||
)
|
)
|
||||||
projection_time = time.time() - start_time
|
projection_time = time.time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class InsightsRetriever(BaseGraphRetriever):
|
||||||
- top_k
|
- top_k
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, exploration_levels: int = 1, top_k: int = 5):
|
def __init__(self, exploration_levels: int = 1, top_k: Optional[int] = 5):
|
||||||
"""Initialize retriever with exploration levels and search parameters."""
|
"""Initialize retriever with exploration levels and search parameters."""
|
||||||
self.exploration_levels = exploration_levels
|
self.exploration_levels = exploration_levels
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
|
||||||
|
|
@ -129,7 +129,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||||
|
|
||||||
vector_search_results = await vector_engine.search(
|
vector_search_results = await vector_engine.search(
|
||||||
collection_name="Event_name", query_vector=query_vector, limit=0
|
collection_name="Event_name", query_vector=query_vector, limit=None
|
||||||
)
|
)
|
||||||
|
|
||||||
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,7 @@ async def brute_force_triplet_search(
|
||||||
async def search_in_collection(collection_name: str):
|
async def search_in_collection(collection_name: str):
|
||||||
try:
|
try:
|
||||||
return await vector_engine.search(
|
return await vector_engine.search(
|
||||||
collection_name=collection_name, query_vector=query_vector, limit=0
|
collection_name=collection_name, query_vector=query_vector, limit=None
|
||||||
)
|
)
|
||||||
except CollectionNotFoundError:
|
except CollectionNotFoundError:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,40 @@
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorEngine:
|
||||||
|
# Test that vector engine search works well with limit=None.
|
||||||
|
# Search should return all triplets that exist. Used Alice for a bit larger test.
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vector_engine_search_none_limit(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_vector_engine_search_none_limit"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_vector_engine_search_none_limit"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
file_path = os.path.join(pathlib.Path(__file__).resolve().parent, "data", "alice_in_wonderland.txt")
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
await cognee.add(file_path)
|
||||||
|
|
||||||
|
await cognee.cognify()
|
||||||
|
|
||||||
|
query_text = "List me all the important characters in Alice in Wonderland."
|
||||||
|
|
||||||
|
# Use high value to make sure we get everything that the vector search returns
|
||||||
|
retriever = GraphCompletionRetriever(top_k=1000)
|
||||||
|
|
||||||
|
result = await retriever.get_triplets(query_text)
|
||||||
|
|
||||||
|
# Check that we did not accidentally use any default value for limit in vector search along the way (like 5, 10, or 15)
|
||||||
|
assert len(result) > 15
|
||||||
|
|
@ -15,6 +15,9 @@ async def cognee_demo():
|
||||||
current_directory = Path(__file__).resolve().parent.parent
|
current_directory = Path(__file__).resolve().parent.parent
|
||||||
file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt")
|
file_path = os.path.join(current_directory, "data", "alice_in_wonderland.txt")
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
# Call Cognee to process document
|
# Call Cognee to process document
|
||||||
await cognee.add(file_path)
|
await cognee.add(file_path)
|
||||||
await cognee.cognify()
|
await cognee.cognify()
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue