feature: cover current context structure with unit test and add time logging to vector collection retrievals (#1144)
<!-- .github/pull_request_template.md --> ## Description Cover current context structure with unit test so it is not changed accidentally in the future ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
7f972d3ab5
commit
9157d3c2dd
3 changed files with 99 additions and 9 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import time
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from typing import List, Dict, Union, Optional, Type
|
from typing import List, Dict, Union, Optional, Type
|
||||||
|
|
||||||
|
|
@ -154,11 +155,16 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
raise ValueError("Failed to generate query embedding.")
|
raise ValueError("Failed to generate query embedding.")
|
||||||
|
|
||||||
if edge_distances is None:
|
if edge_distances is None:
|
||||||
|
start_time = time.time()
|
||||||
edge_distances = await vector_engine.search(
|
edge_distances = await vector_engine.search(
|
||||||
collection_name="EdgeType_relationship_name",
|
collection_name="EdgeType_relationship_name",
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=0,
|
limit=0,
|
||||||
)
|
)
|
||||||
|
projection_time = time.time() - start_time
|
||||||
|
logger.info(
|
||||||
|
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from typing import List, Optional, Type
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
|
|
@ -174,6 +175,8 @@ async def brute_force_search(
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[search_in_collection(collection_name) for collection_name in collections]
|
*[search_in_collection(collection_name) for collection_name in collections]
|
||||||
)
|
)
|
||||||
|
|
@ -181,6 +184,12 @@ async def brute_force_search(
|
||||||
if all(not item for item in results):
|
if all(not item for item in results):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# Final statistics
|
||||||
|
projection_time = time.time() - start_time
|
||||||
|
logger.info(
|
||||||
|
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||||
|
|
||||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||||
|
|
|
||||||
|
|
@ -28,18 +28,38 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
class Company(DataPoint):
|
class Company(DataPoint):
|
||||||
name: str
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
class Person(DataPoint):
|
class Person(DataPoint):
|
||||||
name: str
|
name: str
|
||||||
|
description: str
|
||||||
works_for: Company
|
works_for: Company
|
||||||
|
|
||||||
company1 = Company(name="Figma")
|
company1 = Company(name="Figma", description="Figma is a company")
|
||||||
company2 = Company(name="Canva")
|
company2 = Company(name="Canva", description="Canvas is a company")
|
||||||
person1 = Person(name="Steve Rodger", works_for=company1)
|
person1 = Person(
|
||||||
person2 = Person(name="Ike Loma", works_for=company1)
|
name="Steve Rodger",
|
||||||
person3 = Person(name="Jason Statham", works_for=company1)
|
description="This is description about Steve Rodger",
|
||||||
person4 = Person(name="Mike Broski", works_for=company2)
|
works_for=company1,
|
||||||
person5 = Person(name="Christina Mayer", works_for=company2)
|
)
|
||||||
|
person2 = Person(
|
||||||
|
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
|
||||||
|
)
|
||||||
|
person3 = Person(
|
||||||
|
name="Jason Statham",
|
||||||
|
description="This is description about Jason Statham",
|
||||||
|
works_for=company1,
|
||||||
|
)
|
||||||
|
person4 = Person(
|
||||||
|
name="Mike Broski",
|
||||||
|
description="This is description about Mike Broski",
|
||||||
|
works_for=company2,
|
||||||
|
)
|
||||||
|
person5 = Person(
|
||||||
|
name="Christina Mayer",
|
||||||
|
description="This is description about Christina Mayer",
|
||||||
|
works_for=company2,
|
||||||
|
)
|
||||||
|
|
||||||
entities = [company1, company2, person1, person2, person3, person4, person5]
|
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||||
|
|
||||||
|
|
@ -49,8 +69,63 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
context = await retriever.get_context("Who works at Canva?")
|
context = await retriever.get_context("Who works at Canva?")
|
||||||
|
|
||||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
# Ensure the top-level sections are present
|
||||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
||||||
|
assert "Connections:" in context, "Missing 'Connections:' section in context"
|
||||||
|
|
||||||
|
# --- Nodes headers ---
|
||||||
|
assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
|
||||||
|
assert "Node: Figma" in context, "Missing node header for Figma"
|
||||||
|
assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
|
||||||
|
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
|
||||||
|
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
|
||||||
|
assert "Node: Canva" in context, "Missing node header for Canva"
|
||||||
|
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
|
||||||
|
|
||||||
|
# --- Node contents ---
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Steve Rodger altered"
|
||||||
|
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
|
||||||
|
"Description block for Figma altered"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Ike Loma altered"
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Jason Statham altered"
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Mike Broski altered"
|
||||||
|
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
|
||||||
|
"Description block for Canva altered"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
|
||||||
|
in context
|
||||||
|
), "Description block for Christina Mayer altered"
|
||||||
|
|
||||||
|
# --- Connections ---
|
||||||
|
assert "Steve Rodger --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Steve Rodger→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Ike Loma --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Ike Loma→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Jason Statham --[works_for]--> Figma" in context, (
|
||||||
|
"Connection Jason Statham→Figma missing or changed"
|
||||||
|
)
|
||||||
|
assert "Mike Broski --[works_for]--> Canva" in context, (
|
||||||
|
"Connection Mike Broski→Canva missing or changed"
|
||||||
|
)
|
||||||
|
assert "Christina Mayer --[works_for]--> Canva" in context, (
|
||||||
|
"Connection Christina Mayer→Canva missing or changed"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_context_complex(self):
|
async def test_graph_completion_context_complex(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue