From 9157d3c2ddc6465a345dfa4cf162a1e969289187 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 25 Jul 2025 13:04:43 +0200 Subject: [PATCH] feature: cover current context structure with unit test and add time logging to vector collection retrievals (#1144) ## Description Cover current context structure with unit test so it is not changed accidentally in the future ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .../modules/graph/cognee_graph/CogneeGraph.py | 6 ++ .../utils/brute_force_triplet_search.py | 9 ++ .../graph_completion_retriever_test.py | 93 +++++++++++++++++-- 3 files changed, 99 insertions(+), 9 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 7784e8a64..4e3a2d15a 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,3 +1,4 @@ +import time from cognee.shared.logging_utils import get_logger from typing import List, Dict, Union, Optional, Type @@ -154,11 +155,16 @@ class CogneeGraph(CogneeAbstractGraph): raise ValueError("Failed to generate query embedding.") if edge_distances is None: + start_time = time.time() edge_distances = await vector_engine.search( collection_name="EdgeType_relationship_name", query_vector=query_vector, limit=0, ) + projection_time = time.time() - start_time + logger.info( + f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s" + ) embedding_map = {result.payload["text"]: result.score for result in edge_distances} diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 49f4508a0..44bb10dcb 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR @@ -174,6 +175,8 @@ async def brute_force_search( return [] try: + start_time = time.time() + results = await asyncio.gather( *[search_in_collection(collection_name) for collection_name in collections] ) @@ -181,6 +184,12 @@ async def brute_force_search( if all(not item for item in results): return [] + # Final statistics + projection_time = time.time() - start_time + logger.info( + f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s" + ) + node_distances = {collection: result for collection, result in zip(collections, results)} edge_distances = node_distances.get("EdgeType_relationship_name", None) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 27094fbf1..976b69e69 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -28,18 +28,38 @@ class TestGraphCompletionRetriever: class Company(DataPoint): name: str + description: str class Person(DataPoint): name: str + description: str works_for: Company - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) + company1 = Company(name="Figma", description="Figma is a company") + company2 = Company(name="Canva", description="Canvas is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + person2 = Person( + name="Ike Loma", description="This is description about Ike Loma", works_for=company1 + ) + person3 = Person( + name="Jason Statham", + description="This is description about Jason Statham", + works_for=company1, + ) + person4 = Person( + name="Mike Broski", + description="This is description about Mike Broski", + works_for=company2, + ) + person5 = Person( + name="Christina Mayer", + description="This is description about Christina Mayer", + works_for=company2, + ) entities = [company1, company2, person1, person2, person3, person4, person5] @@ -49,8 +69,63 @@ class TestGraphCompletionRetriever: context = await retriever.get_context("Who works at Canva?") - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + # Ensure the top-level sections are present + assert "Nodes:" in context, "Missing 'Nodes:' section in context" + assert "Connections:" in context, "Missing 'Connections:' section in context" + + # --- Nodes headers --- + assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" + assert "Node: Figma" in context, "Missing node header for Figma" + assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" + assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" + assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" + assert "Node: Canva" in context, "Missing node header for Canva" + assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" + + # --- Node contents --- + assert ( + "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" + in context + ), "Description block for Steve Rodger altered" + assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( + "Description block for Figma altered" + ) + assert ( + "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" + in context + ), "Description block for Ike Loma altered" + assert ( + "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" + in context + ), "Description block for Jason Statham altered" + assert ( + "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" + in context + ), "Description block for Mike Broski altered" + assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( + "Description block for Canva altered" + ) + assert ( + "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" + in context + ), "Description block for Christina Mayer altered" + + # --- Connections --- + assert "Steve Rodger --[works_for]--> Figma" in context, ( + "Connection Steve Rodger→Figma missing or changed" + ) + assert "Ike Loma --[works_for]--> Figma" in context, ( + "Connection Ike Loma→Figma missing or changed" + ) + assert "Jason Statham --[works_for]--> Figma" in context, ( + "Connection Jason Statham→Figma missing or changed" + ) + assert "Mike Broski --[works_for]--> Canva" in context, ( + "Connection Mike Broski→Canva missing or changed" + ) + assert "Christina Mayer --[works_for]--> Canva" in context, ( + "Connection Christina Mayer→Canva missing or changed" + ) @pytest.mark.asyncio async def test_graph_completion_context_complex(self):