From d5a888e6c06932345fd41cd8d658bb12933b5a5b Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 14 Jan 2026 23:15:28 +0100 Subject: [PATCH 01/11] feat: make payload inclusion optional for vector search --- .../vector/chromadb/ChromaDBAdapter.py | 1 + .../vector/lancedb/LanceDBAdapter.py | 44 +++++--- .../databases/vector/models/ScoredResult.py | 6 +- .../vector/pgvector/PGVectorAdapter.py | 100 ++++++++++++------ .../databases/vector/vector_db_interface.py | 4 + .../modules/retrieval/completion_retriever.py | 4 +- 6 files changed, 112 insertions(+), 47 deletions(-) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 3380125ce..19aaa1b39 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, # TODO: Add support for this parameter ): """ Search for items in a collection using either a text or a vector query. 
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 6d724f9d7..d27a084a2 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, + include_payload: bool = False, ): if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -247,21 +248,40 @@ class LanceDBAdapter(VectorDBInterface): if limit <= 0: return [] - result_values = await collection.vector_search(query_vector).limit(limit).to_list() + if include_payload: + result_values = await collection.vector_search(query_vector).limit(limit).to_list() + if not result_values: + return [] + normalized_values = normalize_distances(result_values) - if not result_values: - return [] + return [ + ScoredResult( + id=parse_id(result["id"]), + payload=result["payload"], + score=normalized_values[value_index], + ) + for value_index, result in enumerate(result_values) + ] - normalized_values = normalize_distances(result_values) - - return [ - ScoredResult( - id=parse_id(result["id"]), - payload=result["payload"], - score=normalized_values[value_index], + else: + result_values = await ( + collection.vector_search(query_vector) + .limit(limit) + .select(["id", "vector", "_distance"]) + .to_list() ) - for value_index, result in enumerate(result_values) - ] + if not result_values: + return [] + + normalized_values = normalize_distances(result_values) + + return [ + ScoredResult( + id=parse_id(result["id"]), + score=normalized_values[value_index], + ) + for value_index, result in enumerate(result_values) + ] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/models/ScoredResult.py b/cognee/infrastructure/databases/vector/models/ScoredResult.py index 
0a8cc9888..b4792ce28 100644 --- a/cognee/infrastructure/databases/vector/models/ScoredResult.py +++ b/cognee/infrastructure/databases/vector/models/ScoredResult.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from uuid import UUID from pydantic import BaseModel @@ -12,10 +12,10 @@ class ScoredResult(BaseModel): - id (UUID): Unique identifier for the scored result. - score (float): The score associated with the result, where a lower score indicates a better outcome. - - payload (Dict[str, Any]): Additional information related to the score, stored as + - payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as key-value pairs in a dictionary. """ id: UUID score: float # Lower score is better - payload: Dict[str, Any] + payload: Optional[Dict[str, Any]] = None diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 1986fae48..932e74a8c 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_vector: Optional[List[float]] = None, limit: Optional[int] = 15, with_vector: bool = False, + include_payload: bool = False, ) -> List[ScoredResult]: if query_text is None and query_vector is None: raise MissingQueryParameterError() @@ -324,44 +325,81 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # NOTE: This needs to be initialized in case search doesn't return a value closest_items = [] - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") + if include_payload: + # Use async session to connect to the database + async with 
self.get_async_session() as session: + query = select( + PGVectorDataPoint, + PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") - if limit > 0: - query = query.limit(limit) + if limit > 0: + query = query.limit(limit) - # Find closest vectors to query_vector - closest_items = await session.execute(query) + # Find closest vectors to query_vector + closest_items = await session.execute(query) - vector_list = [] + vector_list = [] - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "payload": vector.payload, - "_distance": vector.similarity, - } - ) + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "payload": vector.payload, + "_distance": vector.similarity, + } + ) - if len(vector_list) == 0: - return [] + if len(vector_list) == 0: + return [] - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] - # Create and return ScoredResult objects - return [ - ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) - for row in vector_list - ] + # Create and return ScoredResult objects + return [ + ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) + for row in vector_list + ] + else: + # Use async session to connect to the database + async with self.get_async_session() as session: + query = select( + PGVectorDataPoint.c.id, + PGVectorDataPoint.c.vector, + 
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") + + if limit > 0: + query = query.limit(limit) + + # Find closest vectors to query_vector + closest_items = await session.execute(query) + + vector_list = [] + + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "_distance": vector.similarity, + } + ) + + if len(vector_list) == 0: + return [] + + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] + + # Create and return ScoredResult objects + return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 12ace1a6c..e38de19e2 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -87,6 +87,7 @@ class VectorDBInterface(Protocol): query_vector: Optional[List[float]], limit: Optional[int], with_vector: bool = False, + include_payload: bool = False, ): """ Perform a search in the specified collection using either a text query or a vector @@ -103,6 +104,9 @@ class VectorDBInterface(Protocol): - limit (Optional[int]): The maximum number of results to return from the search. - with_vector (bool): Whether to return the vector representations with search results. (default False) + - include_payload (bool): Whether to include the payload data with search. Search is faster when set to False. 
+ Payload contains metadata about the data point, useful for searches that are only based on embedding distances + like the RAG_COMPLETION search type, but not needed when search also contains graph data. """ raise NotImplementedError diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 0e9a4167c..f3a7bd505 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever): vector_engine = get_vector_engine() try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + found_chunks = await vector_engine.search( + "DocumentChunk_text", query, limit=self.top_k, include_payload=True + ) if len(found_chunks) == 0: return "" From 51a9ff0613395a68511bf27c4ecba2a66c8f19f9 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 14 Jan 2026 23:33:05 +0100 Subject: [PATCH 02/11] refactor: make include_payload use in vector databases a bit more readable --- .../vector/chromadb/ChromaDBAdapter.py | 2 +- .../vector/lancedb/LanceDBAdapter.py | 54 ++++----- .../vector/pgvector/PGVectorAdapter.py | 109 +++++++----------- 3 files changed, 64 insertions(+), 101 deletions(-) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 19aaa1b39..5e2d3975a 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -355,7 +355,7 @@ class ChromaDBAdapter(VectorDBInterface): limit: Optional[int] = 15, with_vector: bool = False, normalized: bool = True, - include_payload: bool = False, # TODO: Add support for this parameter + include_payload: bool = False, # TODO: Add support for this parameter when set to False ): """ Search for items in a collection using either a text or a vector query. 
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index d27a084a2..49168ffcd 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -248,40 +248,30 @@ class LanceDBAdapter(VectorDBInterface): if limit <= 0: return [] - if include_payload: - result_values = await collection.vector_search(query_vector).limit(limit).to_list() - if not result_values: - return [] - normalized_values = normalize_distances(result_values) + # Note: Exclude payload if not needed to optimize performance + select_columns = ( + ["id", "vector", "payload", "_distance"] + if include_payload + else ["id", "vector", "_distance"] + ) + result_values = ( + await collection.vector_search(query_vector) + .select(select_columns) + .limit(limit) + .to_list() + ) + if not result_values: + return [] + normalized_values = normalize_distances(result_values) - return [ - ScoredResult( - id=parse_id(result["id"]), - payload=result["payload"], - score=normalized_values[value_index], - ) - for value_index, result in enumerate(result_values) - ] - - else: - result_values = await ( - collection.vector_search(query_vector) - .limit(limit) - .select(["id", "vector", "_distance"]) - .to_list() + return [ + ScoredResult( + id=parse_id(result["id"]), + payload=result["payload"] if include_payload else None, + score=normalized_values[value_index], ) - if not result_values: - return [] - - normalized_values = normalize_distances(result_values) - - return [ - ScoredResult( - id=parse_id(result["id"]), - score=normalized_values[value_index], - ) - for value_index, result in enumerate(result_values) - ] + for value_index, result in enumerate(result_values) + ] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py 
b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 932e74a8c..6c73ad475 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -325,81 +325,54 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # NOTE: This needs to be initialized in case search doesn't return a value closest_items = [] - if include_payload: - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") + # Note: Exclude payload from returned columns if not needed to optimize performance + select_columns = ( + [PGVectorDataPoint] + if include_payload + else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector] + ) + # Use async session to connect to the database + async with self.get_async_session() as session: + query = select( + *select_columns, + PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), + ).order_by("similarity") - if limit > 0: - query = query.limit(limit) + if limit > 0: + query = query.limit(limit) - # Find closest vectors to query_vector - closest_items = await session.execute(query) + # Find closest vectors to query_vector + closest_items = await session.execute(query) - vector_list = [] + vector_list = [] - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "payload": vector.payload, - "_distance": vector.similarity, - } - ) + # Extract distances and find min/max for normalization + for vector in closest_items.all(): + vector_list.append( + { + "id": parse_id(str(vector.id)), + "payload": vector.payload if include_payload else None, + "_distance": vector.similarity, + } + ) - if len(vector_list) == 0: - return [] + if len(vector_list) == 
0: + return [] - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] + # Normalize vector distance and add this as score information to vector_list + normalized_values = normalize_distances(vector_list) + for i in range(0, len(normalized_values)): + vector_list[i]["score"] = normalized_values[i] - # Create and return ScoredResult objects - return [ - ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) - for row in vector_list - ] - else: - # Use async session to connect to the database - async with self.get_async_session() as session: - query = select( - PGVectorDataPoint.c.id, - PGVectorDataPoint.c.vector, - PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"), - ).order_by("similarity") - - if limit > 0: - query = query.limit(limit) - - # Find closest vectors to query_vector - closest_items = await session.execute(query) - - vector_list = [] - - # Extract distances and find min/max for normalization - for vector in closest_items.all(): - vector_list.append( - { - "id": parse_id(str(vector.id)), - "_distance": vector.similarity, - } - ) - - if len(vector_list) == 0: - return [] - - # Normalize vector distance and add this as score information to vector_list - normalized_values = normalize_distances(vector_list) - for i in range(0, len(normalized_values)): - vector_list[i]["score"] = normalized_values[i] - - # Create and return ScoredResult objects - return [ScoredResult(id=row.get("id"), score=row.get("score")) for row in vector_list] + # Create and return ScoredResult objects + return [ + ScoredResult( + id=row.get("id"), + payload=row.get("payload") if include_payload else None, + score=row.get("score"), + ) + for row in vector_list + ] async def batch_search( self, From 4f7ab8768376ad657ec1b3c782ec2f963c57e30a Mon Sep 17 00:00:00 2001 
From: Igor Ilic Date: Thu, 15 Jan 2026 11:32:04 +0100 Subject: [PATCH 03/11] refactor: Use include_payload where necessary --- .../hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py | 1 + cognee/modules/retrieval/chunks_retriever.py | 4 +++- cognee/modules/retrieval/triplet_retriever.py | 4 +++- cognee/tests/test_chromadb.py | 2 +- cognee/tests/test_library.py | 2 +- cognee/tests/test_neo4j.py | 4 +++- cognee/tests/test_neptune_analytics_vector.py | 4 +++- cognee/tests/test_pgvector.py | 4 +++- cognee/tests/test_remote_kuzu.py | 4 +++- cognee/tests/test_s3_file_storage.py | 2 +- 10 files changed, 22 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py index 72a1fac01..a8df0b2e0 100644 --- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): query_vector: Optional[List[float]] = None, limit: Optional[int] = None, with_vector: bool = False, + include_payload: bool = False, # TODO: Add support for this parameter ): """ Perform a search in the specified collection using either a text query or a vector diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index ce9b8233b..1a31087d6 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever): vector_engine = get_vector_engine() try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + found_chunks = await vector_engine.search( + "DocumentChunk_text", query, limit=self.top_k, include_payload=True + ) logger.info(f"Found {len(found_chunks)} chunks from vector search") 
await update_node_access_timestamps(found_chunks) diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py index b9d006312..ece9c6f85 100644 --- a/cognee/modules/retrieval/triplet_retriever.py +++ b/cognee/modules/retrieval/triplet_retriever.py @@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever): "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. " ) - found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k) + found_triplets = await vector_engine.search( + "Triplet_text", query, limit=self.top_k, include_payload=True + ) if len(found_triplets) == 0: return "" diff --git a/cognee/tests/test_chromadb.py b/cognee/tests/test_chromadb.py index 767edf3dc..b5d1c4675 100644 --- a/cognee/tests/test_chromadb.py +++ b/cognee/tests/test_chromadb.py @@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit(): query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0] result = await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True ) # Check that we did not accidentally use any default value for limit diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 893b836c0..403bb9e29 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -48,7 +48,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 925614e67..6cc2d7fec 100644 --- a/cognee/tests/test_neo4j.py +++ 
b/cognee/tests/test_neo4j.py @@ -63,7 +63,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_neptune_analytics_vector.py b/cognee/tests/test_neptune_analytics_vector.py index 99c4d94b4..d86dd6a63 100644 --- a/cognee/tests/test_neptune_analytics_vector.py +++ b/cognee/tests/test_neptune_analytics_vector.py @@ -52,7 +52,9 @@ async def main(): await cognee.cognify([dataset_name]) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 240f9e9bb..8e4b3e8f0 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -163,7 +163,9 @@ async def main(): await test_getting_of_documents(dataset_name_1) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_remote_kuzu.py b/cognee/tests/test_remote_kuzu.py index 1c619719c..cea5be904 100644 --- a/cognee/tests/test_remote_kuzu.py +++ b/cognee/tests/test_remote_kuzu.py @@ -58,7 +58,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - 
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_s3_file_storage.py b/cognee/tests/test_s3_file_storage.py index c7fc62cf2..eeb372753 100755 --- a/cognee/tests/test_s3_file_storage.py +++ b/cognee/tests/test_s3_file_storage.py @@ -43,7 +43,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( From e65a6a16792680a19c9373ac588afcb551dc9693 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 15 Jan 2026 11:43:19 +0100 Subject: [PATCH 04/11] refactor: include payload for summaries retriever --- cognee/modules/retrieval/summaries_retriever.py | 2 +- .../modules/retrieval/rag_completion_retriever_test.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 13972bb8d..e79bb514d 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever): try: summaries_results = await vector_engine.search( - "TextSummary_text", query, limit=self.top_k + "TextSummary_text", query, limit=self.top_k, include_payload=True ) logger.info(f"Found {len(summaries_results)} summaries from vector search") diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index e998d419d..4a73ef380 100644 --- 
a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Steve Rodger\nMike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=2, include_payload=True + ) @pytest.mark.asyncio @@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Chunk 0\nChunk 1" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=2, include_payload=True + ) @pytest.mark.asyncio From f35636970849f0c112c14a0f6ccd44f6823d07e2 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 15 Jan 2026 12:00:41 +0100 Subject: [PATCH 05/11] refactor: Update unit tests --- .../tests/unit/modules/retrieval/chunks_retriever_test.py | 8 ++++++-- .../unit/modules/retrieval/summaries_retriever_test.py | 8 ++++++-- .../unit/modules/retrieval/triplet_retriever_test.py | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 98bfd48fe..feb254155 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine): assert len(context) == 2 assert context[0]["text"] == "Steve Rodger" assert context[1]["text"] == "Mike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) + 
mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio @@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) + mock_vector_engine.search.assert_awaited_once_with( + "DocumentChunk_text", "test query", limit=3, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index e552ac74a..7bec8afdf 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine): assert len(context) == 2 assert context[0]["text"] == "S.R." assert context[1]["text"] == "M.B." 
- mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) + mock_vector_engine.search.assert_awaited_once_with( + "TextSummary_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio @@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine): context = await retriever.get_context("test query") assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + mock_vector_engine.search.assert_awaited_once_with( + "TextSummary_text", "test query", limit=3, include_payload=True + ) @pytest.mark.asyncio diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py index 83612c7aa..e914b0aa4 100644 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine): context = await retriever.get_context("test query") assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) + mock_vector_engine.search.assert_awaited_once_with( + "Triplet_text", "test query", limit=5, include_payload=True + ) @pytest.mark.asyncio From 3635bda6cd6a47edd230d35c40c9364fb3eb23db Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 15 Jan 2026 14:51:15 +0100 Subject: [PATCH 06/11] refactor: Use id instead of payload id for temporal retriever --- cognee/modules/retrieval/temporal_retriever.py | 2 +- cognee/tests/test_kuzu.py | 4 +++- cognee/tests/test_lancedb.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 87d2ab009..cebd03a97 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ 
b/cognee/modules/retrieval/temporal_retriever.py @@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever): async def filter_top_k_events(self, relevant_events, scored_results): # Build a score lookup from vector search results - score_lookup = {res.payload["id"]: res.score for res in scored_results} + score_lookup = {res.id: res.score for res in scored_results} events_with_scores = [] for event in relevant_events[0]["events"]: diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index fe9da6dcb..63c9a983f 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -70,7 +70,9 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( diff --git a/cognee/tests/test_lancedb.py b/cognee/tests/test_lancedb.py index 115ba99fd..29b149217 100644 --- a/cognee/tests/test_lancedb.py +++ b/cognee/tests/test_lancedb.py @@ -149,7 +149,9 @@ async def main(): await test_getting_of_documents(dataset_name_1) vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = ( + await vector_engine.search("Entity_name", "Quantum computer", include_payload=True) + )[0] random_node_name = random_node.payload["text"] search_results = await cognee.search( From e51149c3a27b0bcd6786d52b067666200ad7a97c Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 15 Jan 2026 17:25:59 +0100 Subject: [PATCH 07/11] refactor: Update temporal tests --- .../retrieval/temporal_retriever_test.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index 1d2f4c84d..a0459b227 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits(): ] scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), + SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10), + SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20), ] top = await tr.filter_top_k_events(relevant_events, scored_results) @@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k ] scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), + SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05), + SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50), ] top = await tr.filter_top_k_events(relevant_events, scored_results) @@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): tr = TemporalRetriever(top_k=10) relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), + SimpleNamespace(id="a", payload={"id": "a"}, score=0.1), + SimpleNamespace(id="b", payload={"id": "b"}, score=0.2), ] out = await tr.filter_top_k_events(relevant_events, scored_results) assert [e["id"] for e in out] == ["a", "b"] @@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine } ] - mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) - mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, 
score=0.10) mock_vector_engine.search.return_value = [mock_result1, mock_result2] with ( @@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine) } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] mock_user = MagicMock() @@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( @@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector } ] - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05) mock_vector_engine.search.return_value = [mock_result] with ( From 5d412ed19b9b47baebef32191e0863acd3df6114 
Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 16 Jan 2026 13:21:42 +0100 Subject: [PATCH 08/11] refactor: Add batch search support for leaving out payload --- .../neptune_analytics/NeptuneAnalyticsAdapter.py | 16 ++++++++++++++-- .../databases/vector/chromadb/ChromaDBAdapter.py | 1 + .../databases/vector/lancedb/LanceDBAdapter.py | 3 +++ .../databases/vector/pgvector/PGVectorAdapter.py | 2 ++ .../databases/vector/vector_db_interface.py | 4 ++++ 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py index a8df0b2e0..9289bb6c8 100644 --- a/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +++ b/cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py @@ -320,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): self._na_exception_handler(e, query_string) async def batch_search( - self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False + self, + collection_name: str, + query_texts: List[str], + limit: int, + with_vectors: bool = False, + include_payload: bool = False, ): """ Perform a batch search using multiple text queries against a collection. 
@@ -343,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface): data_vectors = await self.embedding_engine.embed_text(query_texts) return await asyncio.gather( *[ - self.search(collection_name, None, vector, limit, with_vectors) + self.search( + collection_name, + None, + vector, + limit, + with_vectors, + include_payload=include_payload, + ) for vector in data_vectors ] ) diff --git a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py index 5e2d3975a..bec97ca94 100644 --- a/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +++ b/cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py @@ -442,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface): query_texts: List[str], limit: int = 5, with_vectors: bool = False, + include_payload: bool = False, ): """ Perform multiple searches in a single request for efficiency, returning results for each diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 49168ffcd..baef75d9e 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -260,6 +260,7 @@ class LanceDBAdapter(VectorDBInterface): .limit(limit) .to_list() ) + if not result_values: return [] normalized_values = normalize_distances(result_values) @@ -279,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface): query_texts: List[str], limit: Optional[int] = None, with_vectors: bool = False, + include_payload: bool = False, ): query_vectors = await self.embedding_engine.embed_text(query_texts) @@ -289,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface): query_vector=query_vector, limit=limit, with_vector=with_vectors, + include_payload=include_payload, ) for query_vector in query_vectors ] diff --git 
a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 6c73ad475..5e2c356ee 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -380,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_texts: List[str], limit: int = None, with_vectors: bool = False, + include_payload: bool = False, ): query_vectors = await self.embedding_engine.embed_text(query_texts) @@ -390,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): query_vector=query_vector, limit=limit, with_vector=with_vectors, + include_payload=include_payload, ) for query_vector in query_vectors ] diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index e38de19e2..4376d8713 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -117,6 +117,7 @@ class VectorDBInterface(Protocol): query_texts: List[str], limit: Optional[int], with_vectors: bool = False, + include_payload: bool = False, ): """ Perform a batch search using multiple text queries against a collection. @@ -129,6 +130,9 @@ class VectorDBInterface(Protocol): - limit (Optional[int]): The maximum number of results to return for each query. - with_vectors (bool): Whether to include vector representations with search results. (default False) + - include_payload (bool): Whether to include the payload data with search results. (default False) Search is faster when this is False. + The payload contains metadata about each data point; it is needed for searches based purely on embedding distances, + such as the RAG_COMPLETION search type, but not when the search also incorporates graph data.
""" raise NotImplementedError From 01a638255240151712972cb1092367585b8e4231 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 19 Jan 2026 13:57:30 +0100 Subject: [PATCH 09/11] refactor: Change payload text use to use edge id --- cognee/modules/graph/cognee_graph/CogneeGraph.py | 13 ++++++------- .../graph/cognee_graph/CogneeGraphElements.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index f67c026d3..4c9fd63d7 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,5 +1,6 @@ import time from cognee.shared.logging_utils import get_logger +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any from cognee.modules.graph.exceptions import ( @@ -205,6 +206,10 @@ class CogneeGraph(CogneeAbstractGraph): key: properties.get(key) for key in edge_properties_to_project } edge_attributes["relationship_type"] = relationship_type + edge_text = properties.get("edge_text") or properties.get("relationship_name") + edge_attributes["edge_type_id"] = ( + generate_edge_id(edge_id=edge_text) if edge_text else None + ) edge = Edge( source_node, @@ -284,13 +289,7 @@ class CogneeGraph(CogneeAbstractGraph): for query_index, scored_results in enumerate(per_query_scored_results): for result in scored_results: - payload = getattr(result, "payload", None) - if not isinstance(payload, dict): - continue - text = payload.get("text") - if not text: - continue - matching_edges = self.edges_by_distance_key.get(str(text)) + matching_edges = self.edges_by_distance_key.get(str(result.id)) if not matching_edges: continue for edge in matching_edges: diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index c9226b6a1..e8e06920d 100644 --- 
a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -141,7 +141,7 @@ class Edge: self.status = np.ones(dimension, dtype=int) def get_distance_key(self) -> Optional[str]: - key = self.attributes.get("edge_text") or self.attributes.get("relationship_type") + key = self.attributes.get("edge_type_id") if key is None: return None return str(key) From 099d78ccfcf80c41c2f0f9aee4d954533922dd51 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 19 Jan 2026 15:55:27 +0100 Subject: [PATCH 10/11] refactor: add edge_type_id generation in add_edge instead of graph projection --- .../modules/graph/cognee_graph/CogneeGraph.py | 10 ++++--- .../unit/modules/graph/cognee_graph_test.py | 27 ++++++++++++++----- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4c9fd63d7..aad6ad858 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -45,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph): def add_edge(self, edge: Edge) -> None: self.edges.append(edge) + + edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + edge.attributes["edge_type_id"] = ( + generate_edge_id(edge_id=edge_text) if edge_text else None + ) # Update edge with generated edge_type_id + edge.node1.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge) key = edge.get_distance_key() @@ -206,10 +212,6 @@ class CogneeGraph(CogneeAbstractGraph): key: properties.get(key) for key in edge_properties_to_project } edge_attributes["relationship_type"] = relationship_type - edge_text = properties.get("edge_text") or properties.get("relationship_name") - edge_attributes["edge_type_id"] = ( - generate_edge_id(edge_id=edge_text) if edge_text else None - ) edge = Edge( source_node, diff --git 
a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index a13031ac5..5e40ce3a6 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import AsyncMock +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node @@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): graph.add_edge(edge) edge_distances = [ - MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph): graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "CONNECTS_TO" edge_distances = [ - MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr ) graph.add_edge(edge) + edge_text = "KNOWS" edge_distances = [ - MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), + MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): ) graph.add_edge(edge) + edge_text = "SOME_OTHER_EDGE" edge_distances = [ - 
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}), ] await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) @@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "A" + edge_2_text = "B" edge_distances = [ - [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0 - [MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1 + [ + MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text}) + ], # query 0 + [ + MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text}) + ], # query 1 ] await graph.map_vector_distances_to_graph_edges( @@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se graph.add_edge(edge1) graph.add_edge(edge2) + edge_1_text = "A" edge_distances = [ - [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped + [ + MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text}) + ], # query 0: only edge1 mapped [], # query 1: no edges mapped ] From 6b8ff648adf41b43fba247cd70ff7088830a2000 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 19 Jan 2026 16:03:40 +0100 Subject: [PATCH 11/11] refactor: Update tests with edge change --- .../modules/retrieval/test_brute_force_triplet_search.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index fcbfd2434..4f41f9e3d 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import ( 
get_memory_fragment, format_triplets, ) +from cognee.modules.engine.utils.generate_edge_id import generate_edge_id from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError @@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes(): ] } + edge_1_text = "relates_to" + edge_2_text = "relates_to" edge_distances_batch = [ - [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})], - [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})], + [MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})], + [MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})], ] await graph.map_vector_distances_to_graph_nodes(