From 51a9ff0613395a68511bf27c4ecba2a66c8f19f9 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 14 Jan 2026 23:33:05 +0100 Subject: [PATCH] refactor: make include_payload use in vector databases a bit more readable --- .../vector/chromadb/ChromaDBAdapter.py | 2 +- .../vector/lancedb/LanceDBAdapter.py | 54 ++++----- .../vector/pgvector/PGVectorAdapter.py | 109 +++++++----------- 3 files changed, 64 insertions(+), 101 deletions(-) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 19aaa1b39..5e2d3975a 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -355,7 +355,7 @@ class ChromaDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, - include_payload: bool = False, # TODO: Add support for this parameter + include_payload: bool = False, # TODO: Add support for this parameter when set to False ): """ Search for items in a collection using either a text or a vector query. diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index d27a084a2..49168ffcd 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -248,40 +248,30 @@ class LanceDBAdapter(VectorDBInterface): if limit <= 0: return [] - if include_payload: - result_values = await collection.vector_search(query_vector).limit(limit).to_list() - if not result_values: - return [] - normalized_values = normalize_distances(result_values) + # Note: Exclude payload if not needed to optimize performance + select_columns = ( + ["id", "vector", "payload", "_distance"] + if include_payload + else ["id", "vector", "_distance"] + ) + result_values = ( + await collection.vector_search(query_vector) + .select(select_columns) + .limit(limit) + .to_list() + ) + if not result_values: + return [] + normalized_values = normalize_distances(result_values) - return [ - ScoredResult( - id=parse_id(result["id"]), - payload=result["payload"], - score=normalized_values[value_index], - ) - for value_index, result in enumerate(result_values) - ] - - else: - result_values = await ( - collection.vector_search(query_vector) - .limit(limit) - .select(["id", "vector", "_distance"]) - .to_list() + return [ + ScoredResult( + id=parse_id(result["id"]), + payload=result["payload"] if include_payload else None, + score=normalized_values[value_index], ) - if not result_values: - return [] - - normalized_values = normalize_distances(result_values) - - return [ - ScoredResult( - id=parse_id(result["id"]), - score=normalized_values[value_index], - ) - for value_index, result in enumerate(result_values) - ] + for value_index, result in enumerate(result_values) + ] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 932e74a8c..6c73ad475 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -325,81 +325,54 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # NOTE: This needs to be initialized in case search doesn't return a value closest_items = [] - if include_payload: - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") + # Note: Exclude payload from returned columns if not needed to optimize performance + select_columns = ( + [PGVectorDataPoint] + if include_payload + else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector] + ) + # Use async session to connect to the database + async with self.get_async_session() as session: + query = select( + *select_columns, + PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") - if limit > 0: - query = query.limit(limit) + if limit > 0: + query = query.limit(limit) - # Find closest vectors to query_vector - closest_items = await session.execute(query) + # Find closest vectors to query_vector + closest_items = await session.execute(query) - vector_list = [] + vector_list = [] - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "payload": vector.payload, - "_distance": vector.similarity, - } - ) + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "payload": vector.payload if include_payload else None, + "_distance": vector.similarity, + } + ) - if len(vector_list) == 0: - return [] + if len(vector_list) == 0: + return [] - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] - # Create and return ScoredResult objects - return [ - ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) - for row in vector_list - ] - else: - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint.c.id, - PGVectorDataPoint.c.vector, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") - - if limit > 0: - query = query.limit(limit) - - # Find closest vectors to query_vector - closest_items = await session.execute(query) - - vector_list = [] - - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "_distance": vector.similarity, - } - ) - - if len(vector_list) == 0: - return [] - - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] - - # Create and return ScoredResult objects - return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list] + # Create and return ScoredResult objects + return [ + ScoredResult( + id=row.get("id"), + payload=row.get("payload") if include_payload else None, + score=row.get("score"), + ) + for row in vector_list + ] async def batch_search( self,