feature: cover current context structure with unit test and add time logging to vector collection retrievals (#1144)

<!-- .github/pull_request_template.md -->

## Description
Cover current context structure with unit test so it is not changed
accidentally in the future

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
hajdul88 2025-07-25 13:04:43 +02:00 committed by GitHub
parent 7f972d3ab5
commit 9157d3c2dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 99 additions and 9 deletions

View file

@ -1,3 +1,4 @@
import time
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union, Optional, Type
@ -154,11 +155,16 @@ class CogneeGraph(CogneeAbstractGraph):
raise ValueError("Failed to generate query embedding.")
if edge_distances is None:
start_time = time.time()
edge_distances = await vector_engine.search(
collection_name="EdgeType_relationship_name",
query_vector=query_vector,
limit=0,
)
projection_time = time.time() - start_time
logger.info(
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
)
embedding_map = {result.payload["text"]: result.score for result in edge_distances}

View file

@ -1,4 +1,5 @@
import asyncio
import time
from typing import List, Optional, Type
from cognee.shared.logging_utils import get_logger, ERROR
@ -174,6 +175,8 @@ async def brute_force_search(
return []
try:
start_time = time.time()
results = await asyncio.gather(
*[search_in_collection(collection_name) for collection_name in collections]
)
@ -181,6 +184,12 @@ async def brute_force_search(
if all(not item for item in results):
return []
# Final statistics
projection_time = time.time() - start_time
logger.info(
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)

View file

@ -28,18 +28,38 @@ class TestGraphCompletionRetriever:
class Company(DataPoint):
name: str
description: str
class Person(DataPoint):
name: str
description: str
works_for: Company
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Steve Rodger", works_for=company1)
person2 = Person(name="Ike Loma", works_for=company1)
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person5 = Person(name="Christina Mayer", works_for=company2)
company1 = Company(name="Figma", description="Figma is a company")
company2 = Company(name="Canva", description="Canvas is a company")
person1 = Person(
name="Steve Rodger",
description="This is description about Steve Rodger",
works_for=company1,
)
person2 = Person(
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
)
person3 = Person(
name="Jason Statham",
description="This is description about Jason Statham",
works_for=company1,
)
person4 = Person(
name="Mike Broski",
description="This is description about Mike Broski",
works_for=company2,
)
person5 = Person(
name="Christina Mayer",
description="This is description about Christina Mayer",
works_for=company2,
)
entities = [company1, company2, person1, person2, person3, person4, person5]
@ -49,8 +69,63 @@ class TestGraphCompletionRetriever:
context = await retriever.get_context("Who works at Canva?")
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
# Ensure the top-level sections are present
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
assert "Connections:" in context, "Missing 'Connections:' section in context"
# --- Nodes headers ---
assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
assert "Node: Figma" in context, "Missing node header for Figma"
assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
assert "Node: Canva" in context, "Missing node header for Canva"
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
# --- Node contents ---
assert (
"__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
in context
), "Description block for Steve Rodger altered"
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
"Description block for Figma altered"
)
assert (
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
in context
), "Description block for Ike Loma altered"
assert (
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
in context
), "Description block for Jason Statham altered"
assert (
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
in context
), "Description block for Mike Broski altered"
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
"Description block for Canva altered"
)
assert (
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
in context
), "Description block for Christina Mayer altered"
# --- Connections ---
assert "Steve Rodger --[works_for]--> Figma" in context, (
"Connection Steve Rodger→Figma missing or changed"
)
assert "Ike Loma --[works_for]--> Figma" in context, (
"Connection Ike Loma→Figma missing or changed"
)
assert "Jason Statham --[works_for]--> Figma" in context, (
"Connection Jason Statham→Figma missing or changed"
)
assert "Mike Broski --[works_for]--> Canva" in context, (
"Connection Mike Broski→Canva missing or changed"
)
assert "Christina Mayer --[works_for]--> Canva" in context, (
"Connection Christina Mayer→Canva missing or changed"
)
@pytest.mark.asyncio
async def test_graph_completion_context_complex(self):