From d5a888e6c06932345fd41cd8d658bb12933b5a5b Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 14 Jan 2026 23:15:28 +0100 Subject: [PATCH] feat: make payload inclusion optional for vector search --- .../vector/chromadb/ChromaDBAdapter.py | 1 + .../vector/lancedb/LanceDBAdapter.py | 44 +++++--- .../databases/vector/models/ScoredResult.py | 6 +- .../vector/pgvector/PGVectorAdapter.py | 100 ++++++++++++------ .../databases/vector/vector_db_interface.py | 4 + .../modules/retrieval/completion_retriever.py | 4 +- 6 files changed, 112 insertions(+), 47 deletions(-) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 3380125ce..19aaa1b39 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, # TODO: Add support for this parameter ): """ Search for items in a collection using either a text or a vector query. diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 6d724f9d7..d27a084a2 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, ): if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -247,21 +248,40 @@ class LanceDBAdapter(VectorDBInterface): if limit <= 0: return [] - result_values = await collection.vector_search(query_vector).limit(limit).to_list() + if include_payload: + result_values = await collection.vector_search(query_vector).limit(limit).to_list() + if not result_values: + return [] + normalized_values = normalize_distances(result_values) - if not result_values: - return [] + return [ + ScoredResult( + id=parse_id(result["id"]), + payload=result["payload"], + score=normalized_values[value_index], + ) + for value_index, result in enumerate(result_values) + ] - normalized_values = normalize_distances(result_values) - - return [ - ScoredResult( - id=parse_id(result["id"]), - payload=result["payload"], - score=normalized_values[value_index], + else: + result_values = await ( + collection.vector_search(query_vector) + .limit(limit) + .select(["id", "vector", "_distance"]) + .to_list() ) - for value_index, result in enumerate(result_values) - ] + if not result_values: + return [] + + normalized_values = normalize_distances(result_values) + + return [ + ScoredResult( + id=parse_id(result["id"]), + score=normalized_values[value_index], + ) + for value_index, result in enumerate(result_values) + ] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/models/ScoredResult.py b/cognee/infrastructure/databases/vector/models/ScoredResult.py index 0a8cc9888..b4792ce28 100644 --- a/cognee/infrastructure/databases/vector/models/ScoredResult.py +++ b/cognee/infrastructure/databases/vector/models/ScoredResult.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from uuid import UUID from pydantic import BaseModel @@ -12,10 +12,10 @@ class ScoredResult(BaseModel): - id (UUID): Unique identifier for the scored result. - score (float): The score associated with the result, where a lower score indicates a better outcome. - - payload (Dict[str, Any]): Additional information related to the score, stored as + - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as key-value pairs in a dictionary. """ id: UUID score: float # Lower score is better - payload: Dict[str, Any] + payload: Optional[Dict[str, Any]] = None diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 1986fae48..932e74a8c 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_vector: Optional[List[float]] = None, limit: Optional[int] = 15, with_vector: bool = False, + include_payload: bool = False, ) -> List[ScoredResult]: if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -324,44 +325,81 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # NOTE: This needs to be initialized in case search doesn't return a value closest_items = [] - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") + if include_payload: + # Use async session to connect to the database + async with self.get_async_session() as session: + query = select( + PGVectorDataPoint, + PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") - if limit > 0: - query = query.limit(limit) + if limit > 0: + query = query.limit(limit) - # Find closest vectors to query_vector - closest_items = await session.execute(query) + # Find closest vectors to query_vector + closest_items = await session.execute(query) - vector_list = [] + vector_list = [] - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "payload": vector.payload, - "_distance": vector.similarity, - } - ) + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "payload": vector.payload, + "_distance": vector.similarity, + } + ) - if len(vector_list) == 0: - return [] + if len(vector_list) == 0: + return [] - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] - # Create and return ScoredResult objects - return [ - ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) - for row in vector_list - ] + # Create and return ScoredResult objects + return [ + ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) + for row in vector_list + ] + else: + # Use async session to connect to the database + async with self.get_async_session() as session: + query = select( + PGVectorDataPoint.c.id, + PGVectorDataPoint.c.vector, + PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") + + if limit > 0: + query = query.limit(limit) + + # Find closest vectors to query_vector + closest_items = await session.execute(query) + + vector_list = [] + + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "_distance": vector.similarity, + } + ) + + if len(vector_list) == 0: + return [] + + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] + + # Create and return ScoredResult objects + return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 12ace1a6c..e38de19e2 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -87,6 +87,7 @@ class VectorDBInterface(Protocol): query_vector: Optional[List[float]], limit: Optional[int], with_vector: bool = False, + include_payload: bool = False, ): """ Perform a search in the specified collection using either a text query or a vector @@ -103,6 +104,9 @@ class VectorDBInterface(Protocol): - limit (Optional[int]): The maximum number of results to return from the search. - with_vector (bool): Whether to return the vector representations with search results. (default False) + - include_payload (bool): Whether to include the payload data with search. Search is faster when set to False. + Payload contains metadata about the data point, useful for searches that are only based on embedding distances + like the RAG_COMPLETION search type, but not needed when search also contains graph data. """ raise NotImplementedError diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 0e9a4167c..f3a7bd505 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever): vector_engine = get_vector_engine() try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + found_chunks = await vector_engine.search( + "DocumentChunk_text", query, limit=self.top_k, include_payload=True + ) if len(found_chunks) == 0: return ""