chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881)

<!-- .github/pull_request_template.md -->

## Description
This PR restructures/adds integration and unit tests for the retrieval
module.

- Old integration tests were updated and moved under unit tests, and
fixtures were added
- Added missing unit tests for all core retrieval business logic
- Covered 100% of the core retrievers with tests
- Minor changes (dead code removed, typo fixed)

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Changes**
* TripletRetriever now returns up to 5 results by default (was 1),
providing richer context.

* **Tests**
* Reorganized test coverage: many unit tests removed and replaced with
comprehensive integration tests across retrieval components (graph,
chunks, RAG, summaries, temporal, triplets, structured output).

* **Chores**
  * Simplified triplet formatting logic and removed debug output.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
hajdul88 2025-12-16 11:11:29 +01:00 committed by GitHub
parent 78028b819f
commit 4e8845c117
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1888 additions and 2303 deletions

View file

@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.top_k = top_k if top_k is not None else 5
self.system_prompt = system_prompt
async def get_context(self, query: str) -> str:

View file

@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)
def format_triplets(edges):
print("\n\n\n")
def filter_attributes(obj, attributes):
"""Helper function to filter out non-None properties, including nested dicts."""
result = {}
for attr in attributes:
value = getattr(obj, attr, None)
if value is not None:
# If the value is a dict, extract relevant keys from it
if isinstance(value, dict):
nested_values = {
k: v for k, v in value.items() if k in attributes and v is not None
}
result[attr] = nested_values
else:
result[attr] = value
return result
triplets = []
for edge in edges:
node1 = edge.node1

View file

@ -0,0 +1,252 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import List
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    """Payload schema used to create the chunk vector collection in tests."""

    # Chunk payload fields mirroring DocumentChunk.
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): annotated List[Entity] but defaulted to None — presumably
    # intended as Optional; confirm DataPoint (pydantic) accepts this as-is.
    contains: List[Entity] = None
    # Tells the vector engine which fields to embed/index.
    metadata: dict = {"index_fields": ["text"]}
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Set up a clean test environment with simple chunks.

    Points the cognee system/data roots at dedicated directories, wipes prior
    state, stores three one-name DocumentChunks from a single TextDocument,
    then prunes everything again on teardown.
    """
    # Repo-relative base dir derived from this test file's location.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate (data + system metadata), then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Three tiny chunks under one document; tests match on the exact text.
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup; failures here must not mask test results.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Set up a clean test environment with complex chunks.

    Stores six chunks split across two documents (employees vs. cars) so that
    top_k and relevance-ordering behavior can be exercised.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Chunks 1–3 belong to the employee document.
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document1,
        contains=[],
    )
    # Chunks 4–6 belong to the car document.
    chunk4 = DocumentChunk(
        text="Range Rover",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk5 = DocumentChunk(
        text="Hyundai",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    chunk6 = DocumentChunk(
        text="Chrysler",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document2,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without chunks."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    # NOTE(review): unlike the other fixtures, setup() is not awaited here —
    # the empty-graph test creates its own vector collection; confirm intended.
    yield
    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: verify ChunksRetriever can retrieve multiple chunks."""
    results = await ChunksRetriever().get_context("Steve")

    assert isinstance(results, list), "Context should be a list"
    assert len(results) > 0, "Context should not be empty"
    # The seeded "Steve Rodger" chunk must appear among the hits.
    matches = [entry for entry in results if entry["text"] == "Steve Rodger"]
    assert matches, "Failed to get Steve Rodger chunk"
@pytest.mark.asyncio
async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever respects top_k parameter."""
    # With top_k=2 the retriever must cap results at two chunks.
    results = await ChunksRetriever(top_k=2).get_context("Employee")
    assert isinstance(results, list), "Context should be a list"
    assert len(results) <= 2, "Should respect top_k limit"
@pytest.mark.asyncio
async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever can retrieve chunk context (complex)."""
    retriever = ChunksRetriever(top_k=20)
    results = await retriever.get_context("Christina")
    # The best-ranked chunk for "Christina" must be the exact seeded text.
    top_hit = results[0]
    assert top_hit["text"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify ChunksRetriever handles empty graph correctly."""
    retriever = ChunksRetriever()

    # Without the chunk collection, retrieval must raise NoDataError.
    with pytest.raises(NoDataError):
        await retriever.get_context("Christina Mayer")

    # Once the (empty) collection exists, retrieval succeeds with no hits.
    engine = get_vector_engine()
    await engine.create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )
    results = await retriever.get_context("Christina Mayer")
    assert not results, "Found chunks when none should exist"

View file

@ -0,0 +1,268 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Seeds two Company nodes and five Person nodes (each with a description)
    linked via works_for edges, then prunes everything on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    company1 = Company(name="Figma", description="Figma is a company")
    # NOTE(review): "Canvas" is presumably a typo for "Canva", but the simple
    # context test asserts this exact description text — keep in sync.
    company2 = Company(name="Canva", description="Canvas is a company")
    person1 = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=company1,
    )
    person2 = Person(
        name="Ike Loma", description="This is description about Ike Loma", works_for=company1
    )
    person3 = Person(
        name="Jason Statham",
        description="This is description about Jason Statham",
        works_for=company1,
    )
    person4 = Person(
        name="Mike Broski",
        description="This is description about Mike Broski",
        works_for=company2,
    )
    person5 = Person(
        name="Christina Mayer",
        description="This is description about Christina Mayer",
        works_for=company2,
    )
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Builds a people/companies graph where some people also own cars and/or
    homes, exercising nested and union-typed relations.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str
        # Index the company name for vector search.
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        # People may own any mix of cars and homes.
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Clean slate + initialization, but no data points are added.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionRetriever can retrieve context (simple)."""
    retriever = GraphCompletionRetriever()
    edges = await retriever.get_context("Who works at Canva?")
    context = await resolve_edges_to_text(edges)

    # Ensure the top-level sections are present
    assert "Nodes:" in context, "Missing 'Nodes:' section in context"
    assert "Connections:" in context, "Missing 'Connections:' section in context"

    # --- Nodes headers ---
    header_expectations = [
        ("Node: Steve Rodger", "Missing node header for Steve Rodger"),
        ("Node: Figma", "Missing node header for Figma"),
        ("Node: Ike Loma", "Missing node header for Ike Loma"),
        ("Node: Jason Statham", "Missing node header for Jason Statham"),
        ("Node: Mike Broski", "Missing node header for Mike Broski"),
        ("Node: Canva", "Missing node header for Canva"),
        ("Node: Christina Mayer", "Missing node header for Christina Mayer"),
    ]
    for header, failure_message in header_expectations:
        assert header in context, failure_message

    # --- Node contents ---
    content_expectations = [
        ("This is description about Steve Rodger", "Description block for Steve Rodger altered"),
        ("Figma is a company", "Description block for Figma altered"),
        ("This is description about Ike Loma", "Description block for Ike Loma altered"),
        ("This is description about Jason Statham", "Description block for Jason Statham altered"),
        ("This is description about Mike Broski", "Description block for Mike Broski altered"),
        ("Canvas is a company", "Description block for Canva altered"),
        ("This is description about Christina Mayer", "Description block for Christina Mayer altered"),
    ]
    for description, failure_message in content_expectations:
        block = "__node_content_start__\n" + description + "\n__node_content_end__"
        assert block in context, failure_message

    # --- Connections ---
    edge_expectations = [
        ("Steve Rodger --[works_for]--> Figma", "Connection Steve Rodger→Figma missing or changed"),
        ("Ike Loma --[works_for]--> Figma", "Connection Ike Loma→Figma missing or changed"),
        ("Jason Statham --[works_for]--> Figma", "Connection Jason Statham→Figma missing or changed"),
        ("Mike Broski --[works_for]--> Canva", "Connection Mike Broski→Canva missing or changed"),
        ("Christina Mayer --[works_for]--> Canva", "Connection Christina Mayer→Canva missing or changed"),
    ]
    for edge_text, failure_message in edge_expectations:
        assert edge_text in context, failure_message
@pytest.mark.asyncio
async def test_graph_completion_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionRetriever can retrieve context (complex)."""
    edges = await GraphCompletionRetriever(top_k=20).get_context("Who works at Figma?")
    context = await resolve_edges_to_text(edges)
    # All three Figma employees must appear as works_for edges.
    expectations = [
        ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
    ]
    for edge_text, failure_message in expectations:
        assert edge_text in context, failure_message
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever handles empty graph correctly."""
    # No data was seeded, so the retriever must come back empty-handed.
    result = await GraphCompletionRetriever().get_context("Who works at Figma?")
    assert result == [], "Context should be empty on an empty graph"
@pytest.mark.asyncio
async def test_graph_completion_get_triplets_empty(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph."""
    found = await GraphCompletionRetriever().get_triplets("Who works at Figma?")
    assert isinstance(found, list), "Triplets should be a list"
    assert len(found) == 0, "Should return empty list on empty graph"

View file

@ -0,0 +1,226 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Seeds two Company nodes and five Person nodes linked via works_for edges,
    then prunes everything on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_extension_context_simple"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_graph_completion_extension_context_simple"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Steve Rodger", works_for=company1)
    person2 = Person(name="Ike Loma", works_for=company1)
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person5 = Person(name="Christina Mayer", works_for=company2)
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Builds a people/companies graph where some people also own cars and/or
    homes, exercising nested and union-typed relations.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_extension_context_complex"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_graph_completion_extension_context_complex"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str
        # Index the company name for vector search.
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        # People may own any mix of cars and homes.
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Clean slate + initialization, but no data points are added.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple)."""
    retriever = GraphCompletionContextExtensionRetriever()
    edges = await retriever.get_context("Who works at Canva?")
    context = await resolve_edges_to_text(edges)
    assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

    # The completion must be a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Canva?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex)."""
    retriever = GraphCompletionContextExtensionRetriever(top_k=20)
    edges = await retriever.get_context("Who works at Figma and drives Tesla?")
    context = await resolve_edges_to_text(edges)
    expectations = [
        ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
    ]
    for edge_text, failure_message in expectations:
        assert edge_text in context, failure_message

    # The completion must be a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Figma?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly."""
    retriever = GraphCompletionContextExtensionRetriever()
    result = await retriever.get_context("Who works at Figma?")
    assert result == [], "Context should be empty on an empty graph"

    # Even with no context, the completion is still a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Figma?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty):
    """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph."""
    found = await GraphCompletionContextExtensionRetriever().get_triplets("Who works at Figma?")
    assert isinstance(found, list), "Triplets should be a list"
    assert len(found) == 0, "Should return empty list on empty graph"

View file

@ -0,0 +1,218 @@
import os
import pytest
import pathlib
import pytest_asyncio
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Seeds two Company nodes and five Person nodes linked via works_for edges,
    then prunes everything on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_simple"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Steve Rodger", works_for=company1)
    person2 = Person(name="Ike Loma", works_for=company1)
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person5 = Person(name="Christina Mayer", works_for=company2)
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data.

    Builds a people/companies graph where some people also own cars and/or
    homes, exercising nested and union-typed relations.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_complex"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Start from a clean slate, then initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc DataPoint models local to this fixture.
    class Company(DataPoint):
        name: str
        # Index the company name for vector search.
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        # People may own any mix of cars and homes.
        owns: Optional[list[Union[Car, Home]]] = None

    company1 = Company(name="Figma")
    company2 = Company(name="Canva")
    person1 = Person(name="Mike Rodger", works_for=company1)
    person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
    person2 = Person(name="Ike Loma", works_for=company1)
    person2.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]
    person3 = Person(name="Jason Statham", works_for=company1)
    person4 = Person(name="Mike Broski", works_for=company2)
    person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
    person5 = Person(name="Christina Mayer", works_for=company2)
    person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
    entities = [company1, company2, person1, person2, person3, person4, person5]
    await add_data_points(entities)

    yield

    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph"
    )
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    # Clean slate + initialization, but no data points are added.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    # Teardown: best-effort cleanup.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple)."""
    retriever = GraphCompletionCotRetriever()
    edges = await retriever.get_context("Who works at Canva?")
    context = await resolve_edges_to_text(edges)
    assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
    assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

    # The completion must be a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Canva?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex)."""
    retriever = GraphCompletionCotRetriever(top_k=20)
    edges = await retriever.get_context("Who works at Figma?")
    context = await resolve_edges_to_text(edges)
    expectations = [
        ("Mike Rodger --[works_for]--> Figma", "Failed to get Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Failed to get Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Failed to get Jason Statham"),
    ]
    for edge_text, failure_message in expectations:
        assert edge_text in context, failure_message

    # The completion must be a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Figma?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly."""
    retriever = GraphCompletionCotRetriever()
    result = await retriever.get_context("Who works at Figma?")
    assert result == [], "Context should be empty on an empty graph"

    # Even with no context, the completion is still a list of non-empty strings.
    answer = await retriever.get_completion("Who works at Figma?")
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    only_non_empty = all(isinstance(part, str) and part.strip() for part in answer)
    assert only_non_empty, "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty):
    """Integration test: get_triplets on an empty graph returns an empty list."""
    cot_retriever = GraphCompletionCotRetriever()
    found_triplets = await cot_retriever.get_triplets("Who works at Figma?")
    assert isinstance(found_triplets, list), "Triplets should be a list"
    assert len(found_triplets) == 0, "Should return empty list on empty graph"

View file

@ -0,0 +1,254 @@
import os
from typing import List
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    """DocumentChunk-shaped payload schema, used to create an otherwise-empty
    "DocumentChunk_text" vector collection in the empty-graph test below."""

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default of None contradicts the List[Entity] annotation;
    # Optional[List[Entity]] would be accurate — confirm the DataPoint model allows it.
    contains: List[Entity] = None
    # index_fields tells the vector engine which attribute(s) to embed.
    metadata: dict = {"index_fields": ["text"]}
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Provision an isolated cognee environment seeded with three tiny chunks."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_rag_completion_context_simple")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_rag_completion_context_simple")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # Build one chunk per name, all attached to the same document.
    chunk_texts = ["Steve Rodger", "Mike Broski", "Christina Mayer"]
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        for index, chunk_text in enumerate(chunk_texts)
    ]
    await add_data_points(chunks)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Provision an isolated cognee environment seeded with chunks from two documents."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_rag_completion_context_complex")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_rag_completion_context_complex")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Document name -> ordered chunk texts; dict order preserves insertion order.
    corpus = {
        "Employee List": ["Steve Rodger", "Mike Broski", "Christina Mayer"],
        "Car List": ["Range Rover", "Hyundai", "Chrysler"],
    }
    chunks = []
    for document_name, chunk_texts in corpus.items():
        document = TextDocument(
            name=document_name,
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        for index, chunk_text in enumerate(chunk_texts):
            chunks.append(
                DocumentChunk(
                    text=chunk_text,
                    chunk_size=2,
                    chunk_index=index,
                    cut_type="sentence_end",
                    is_part_of=document,
                    contains=[],
                )
            )
    await add_data_points(chunks)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Provision a pruned environment with no chunks so NoDataError paths fire."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever returns string context mentioning Mike Broski."""
    rag_retriever = CompletionRetriever()
    mike_context = await rag_retriever.get_context("Mike")
    assert isinstance(mike_context, str), "Context should be a string"
    assert "Mike Broski" in mike_context, "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever finds the right chunk among several."""
    rag_retriever = CompletionRetriever()
    steve_context = await rag_retriever.get_context("Steve")
    assert isinstance(steve_context, str), "Context should be a string"
    assert "Steve Rodger" in steve_context, "Failed to get Steve Rodger"
@pytest.mark.asyncio
async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: verify CompletionRetriever can retrieve context (complex)."""
    # TODO: top_k doesn't affect the output, it should be fixed.
    retriever = CompletionRetriever(top_k=20)
    context = await retriever.get_context("Christina")
    # startswith() replaces the brittle fixed-width slice context[0:15], which
    # silently depended on len("Christina Mayer") == 15; a type guard first
    # gives a clearer failure on a non-string return.
    assert isinstance(context, str), "Context should be a string"
    assert context.startswith("Christina Mayer"), "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: NoDataError without a collection; "" once an empty one exists."""
    rag_retriever = CompletionRetriever()
    # No collection at all -> the retriever must raise.
    with pytest.raises(NoDataError):
        await rag_retriever.get_context("Christina Mayer")
    # Create the (empty) collection; the same query now returns empty context.
    await get_vector_engine().create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )
    empty_context = await rag_retriever.get_context("Christina Mayer")
    assert empty_context == "", "Returned context should be empty on an empty graph"

View file

@ -1,9 +1,9 @@
import asyncio
import pytest
import cognee
import pathlib
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from pydantic import BaseModel
from cognee.low_level import setup, DataPoint
@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion():
_assert_structured_answer(structured_answer)
class TestStructuredOutputCompletion:
@pytest.mark.asyncio
async def test_get_structured_completion(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
)
cognee.config.data_root_directory(data_directory_path)
@pytest_asyncio.fixture
async def setup_test_environment():
    """Set up a clean test environment with graph and document data."""
    # Route cognee's system/data storage into per-test directories so runs are isolated.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion")
    data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion")
    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    # Ad-hoc graph schema: one person employed by one company.
    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company
        works_since: int

    company1 = Company(name="Figma")
    person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
    entities = [company1, person1]
    await add_data_points(entities)

    # Document chunks so RAG-style retrievers have text to search.
    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3]
    await add_data_points(entities)

    # One named entity so the entity-completion retriever has something to find.
    entity_type = EntityType(name="Person", description="A human individual")
    entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
    entities = [entity]
    await add_data_points(entities)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        # NOTE(review): calling setup() during teardown is unusual — this looks
        # like diff residue from the pre-refactor test body; confirm it is intended.
        await setup()
    except Exception:
        pass
class Company(DataPoint):
name: str
class Person(DataPoint):
name: str
works_for: Company
works_since: int
company1 = Company(name="Figma")
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
entities = [company1, person1]
await add_data_points(entities)
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
entity_type = EntityType(name="Person", description="A human individual")
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
entities = [entity]
await add_data_points(entities)
await _test_get_structured_graph_completion_cot()
await _test_get_structured_graph_completion()
await _test_get_structured_graph_completion_temporal()
await _test_get_structured_graph_completion_rag()
await _test_get_structured_graph_completion_context_extension()
await _test_get_structured_entity_completion()
@pytest.mark.asyncio
async def test_get_structured_completion(setup_test_environment):
    """Integration test: verify structured output completion for all retrievers."""
    structured_checks = (
        _test_get_structured_graph_completion_cot,
        _test_get_structured_graph_completion,
        _test_get_structured_graph_completion_temporal,
        _test_get_structured_graph_completion_rag,
        _test_get_structured_graph_completion_context_extension,
        _test_get_structured_entity_completion,
    )
    # Run each retriever's structured-output check sequentially.
    for structured_check in structured_checks:
        await structured_check()

View file

@ -0,0 +1,184 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
@pytest_asyncio.fixture
async def setup_test_environment_with_summaries():
    """Provision a clean environment seeded with six chunk summaries across two documents."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_summaries_retriever_context")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_summaries_retriever_context")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    employee_doc = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    car_doc = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # (chunk text, summary text, parent document, chunk index) for each summary.
    summary_specs = [
        ("Steve Rodger", "S.R.", employee_doc, 0),
        ("Mike Broski", "M.B.", employee_doc, 1),
        ("Christina Mayer", "C.M.", employee_doc, 2),
        ("Range Rover", "R.R.", car_doc, 0),
        ("Hyundai", "H.Y.", car_doc, 1),
        ("Chrysler", "C.H.", car_doc, 2),
    ]
    summaries = [
        TextSummary(
            text=summary_text,
            made_from=DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=chunk_index,
                cut_type="sentence_end",
                is_part_of=parent_doc,
                contains=[],
            ),
        )
        for chunk_text, summary_text, parent_doc, chunk_index in summary_specs
    ]
    # Only the summaries are stored directly; chunks ride along via made_from.
    await add_data_points(summaries)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Provision a pruned environment with no summaries at all."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_summaries_retriever_context_empty")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_summaries_retriever_context_empty")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_summaries_retriever_context(setup_test_environment_with_summaries):
    """Integration test: Christina's summary ranks first for a 'Christina' query."""
    summaries_retriever = SummariesRetriever(top_k=20)
    summary_context = await summaries_retriever.get_context("Christina")
    assert isinstance(summary_context, list), "Context should be a list"
    assert len(summary_context) > 0, "Context should not be empty"
    assert summary_context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: NoDataError without a summary collection; [] once one exists."""
    summaries_retriever = SummariesRetriever()
    with pytest.raises(NoDataError):
        await summaries_retriever.get_context("Christina Mayer")
    # Create the (empty) summary collection; the same query now returns nothing.
    await get_vector_engine().create_collection("TextSummary_text", payload_schema=TextSummary)
    empty_context = await summaries_retriever.get_context("Christina Mayer")
    assert empty_context == [], "Returned context should be empty on an empty graph"

View file

@ -0,0 +1,306 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.engine.models.Event import Event
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.modules.engine.models.Interval import Interval
@pytest_asyncio.fixture
async def setup_test_environment_with_events():
    """Provision a clean environment seeded with four 2021 events (one interval-based)."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_temporal_retriever_with_events")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_temporal_retriever_with_events")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def midnight_timestamp(epoch, year, month, day, iso_string):
        # All fixture timestamps are at midnight; only the date portion varies.
        return Timestamp(
            time_at=epoch,
            year=year,
            month=month,
            day=day,
            hour=0,
            minute=0,
            second=0,
            timestamp_str=iso_string,
        )

    january = midnight_timestamp(1609459200, 2021, 1, 1, "2021-01-01T00:00:00")
    february = midnight_timestamp(1612137600, 2021, 2, 1, "2021-02-01T00:00:00")
    march = midnight_timestamp(1614556800, 2021, 3, 1, "2021-03-01T00:00:00")
    july = midnight_timestamp(1625097600, 2021, 7, 1, "2021-07-01T00:00:00")
    october = midnight_timestamp(1633046400, 2021, 10, 1, "2021-10-01T00:00:00")

    # One event spans Feb->Mar via an interval; the rest are point-in-time.
    q1_interval = Interval(time_from=february, time_to=march)

    seeded_events = [
        Event(
            name="Project Alpha Launch",
            description="Launched Project Alpha at the beginning of 2021",
            at=january,
            location="San Francisco",
        ),
        Event(
            name="Team Meeting",
            description="Monthly team meeting discussing Q1 goals",
            during=q1_interval,
            location="New York",
        ),
        Event(
            name="Product Release",
            description="Released new product features in July",
            at=july,
            location="Remote",
        ),
        Event(
            name="Company Retreat",
            description="Annual company retreat in October",
            at=october,
            location="Lake Tahoe",
        ),
    ]
    await add_data_points(seeded_events)
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_with_graph_data():
    """Seed a tiny Person->Company graph so the triplet-search fallback has data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_temporal_retriever_with_graph")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_temporal_retriever_with_graph")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    figma = Company(name="Figma", description="Figma is a company")
    steve = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=figma,
    )
    await add_data_points([figma, steve])
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Initialized but data-free environment for empty-graph behavior checks."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(base_dir / ".cognee_system/test_temporal_retriever_empty")
    )
    cognee.config.data_root_directory(
        str(base_dir / ".data_storage/test_temporal_retriever_empty")
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    yield
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
    """Integration test: the January 2021 event is retrievable via a time-range query."""
    temporal_retriever = TemporalRetriever(top_k=5)
    january_context = await temporal_retriever.get_context("What happened in January 2021?")
    assert isinstance(january_context, str), "Context should be a string"
    assert len(january_context) > 0, "Context should not be empty"
    assert any(token in january_context for token in ("Project Alpha", "Launch")), (
        "Should retrieve Project Alpha Launch event from January 2021"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
    """Integration test: the July 2021 event is retrievable via a point-in-time query."""
    temporal_retriever = TemporalRetriever(top_k=5)
    july_context = await temporal_retriever.get_context("What happened in July 2021?")
    assert isinstance(july_context, str), "Context should be a string"
    assert len(july_context) > 0, "Context should not be empty"
    assert any(token in july_context for token in ("Product Release", "July")), (
        "Should retrieve Product Release event from July 2021"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_context_fallback_to_triplets(
    setup_test_environment_with_graph_data,
):
    """Integration test: without an extractable time, the retriever uses triplet search."""
    temporal_retriever = TemporalRetriever(top_k=5)
    fallback_context = await temporal_retriever.get_context("Who works at Figma?")
    assert isinstance(fallback_context, str), "Context should be a string"
    assert len(fallback_context) > 0, "Context should not be empty"
    assert any(token in fallback_context for token in ("Steve", "Figma")), (
        "Should retrieve graph data via triplet search fallback"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
    """Integration test: verify TemporalRetriever handles empty graph correctly."""
    retriever = TemporalRetriever()
    context = await retriever.get_context("What happened?")
    # The old second assertion (`len(context) >= 0`) was vacuous — it holds for
    # every sized object — so the type check is the actual contract being tested.
    assert isinstance(context, str), "Context should be a string"
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
    """Integration test: get_completion yields a non-empty list of non-empty strings."""
    temporal_retriever = TemporalRetriever()
    answers = await temporal_retriever.get_completion("What happened in January 2021?")
    assert isinstance(answers, list), "Completion should be a list"
    assert len(answers) > 0, "Completion should not be empty"
    assert all(isinstance(entry, str) and entry.strip() for entry in answers), (
        "Completion items should be non-empty strings"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
    """Integration test: get_completion also works when triplet fallback supplies context."""
    temporal_retriever = TemporalRetriever()
    answers = await temporal_retriever.get_completion("Who works at Figma?")
    assert isinstance(answers, list), "Completion should be a list"
    assert len(answers) > 0, "Completion should not be empty"
    assert all(isinstance(entry, str) and entry.strip() for entry in answers), (
        "Completion items should be non-empty strings"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
    """Integration test: with top_k=2 at most two events come back."""
    capped_retriever = TemporalRetriever(top_k=2)
    capped_context = await capped_retriever.get_context("What happened in 2021?")
    assert isinstance(capped_context, str), "Context should be a string"
    # Two events are joined by at most one separator banner.
    assert capped_context.count("#####################") <= 1, (
        "Should respect top_k limit of 2 events"
    )
@pytest.mark.asyncio
async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
    """Integration test: a broad 2021 query surfaces at least one seeded event."""
    broad_retriever = TemporalRetriever(top_k=10)
    events_context = await broad_retriever.get_context("What events occurred in 2021?")
    assert isinstance(events_context, str), "Context should be a string"
    assert len(events_context) > 0, "Context should not be empty"
    known_events = ("Project Alpha", "Team Meeting", "Product Release", "Company Retreat")
    assert any(event_name in events_context for event_name in known_events), (
        "Should retrieve at least one event from 2021"
    )

View file

@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip
context = await retriever.get_context("Alice")
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
assert isinstance(context, str), "Context should be a string"
assert len(context) > 0, "Context should not be empty"
@pytest.mark.asyncio
async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets):
    """Integration test: TripletRetriever surfaces at least one Bob-related triplet."""
    triplet_retriever = TripletRetriever(top_k=5)
    bob_context = await triplet_retriever.get_context("Bob")
    expected_snippets = ("Alice knows Bob", "Bob works at Tech Corp")
    assert any(snippet in bob_context for snippet in expected_snippets), (
        "Failed to get Bob-related triplets"
    )
@pytest.mark.asyncio
async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets):
    """Integration test: TripletRetriever with top_k=1 still returns a string context."""
    limited_retriever = TripletRetriever(top_k=1)
    limited_context = await limited_retriever.get_context("Alice")
    assert isinstance(limited_context, str), "Context should be a string"
@pytest.mark.asyncio
async def test_triplet_retriever_context_empty(setup_test_environment_empty):
    """Integration test: TripletRetriever raises NoDataError when no data exists."""
    await setup()
    empty_retriever = TripletRetriever()
    with pytest.raises(NoDataError):
        await empty_retriever.get_context("Alice")

View file

@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
"""
# One-record corpus shared by the HotpotQA/TwoWikiMultihop adapter tests;
# patched in via _get_raw_corpus so no real benchmark dataset is downloaded.
MOCK_HOTPOT_CORPUS = [
    {
        "_id": "1",
        "question": "Next to which country is Germany located?",
        "answer": "Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        "level": "easy",
        "type": "comparison",
        "context": [
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        "supporting_facts": [["Netherlands", 0]],
    }
]
ADAPTER_CLASSES = [
HotpotQAAdapter,
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
adapter = AdapterClass()
result = adapter.load_corpus()
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
result = adapter.load_corpus()
else:
adapter = AdapterClass()
result = adapter.load_corpus()
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
else:
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)

View file

@ -2,15 +2,38 @@ import pytest
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
from cognee.infrastructure.databases.graph import get_graph_engine
from unittest.mock import AsyncMock, patch
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
# Benchmark adapters exercised by the parametrized corpus-builder tests below.
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
# One-record corpus patched in via HotpotQAAdapter._get_raw_corpus so the
# corpus-builder tests never download a real benchmark dataset.
MOCK_HOTPOT_CORPUS = [
    {
        "_id": "1",
        "question": "Next to which country is Germany located?",
        "answer": "Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        "level": "easy",
        "type": "comparison",
        "context": [
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        "supporting_facts": [["Netherlands", 0]],
    }
]
@pytest.mark.parametrize("benchmark", benchmark_options)
def test_corpus_builder_load_corpus(benchmark):
limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
else:
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
    """build_corpus honors the limit for every benchmark, with downloads mocked out."""
    limit = 2
    # NOTE(review): the next two lines duplicate the work done in both branches
    # of the if/else below — they look like diff residue from the pre-mock
    # version of this test; confirm and remove.
    corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
    questions = await corpus_builder.build_corpus(limit=limit)
    if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
        # Dataset-backed adapters get a canned corpus instead of a download.
        with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
            corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
            questions = await corpus_builder.build_corpus(limit=limit)
    else:
        corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
        questions = await corpus_builder.build_corpus(limit=limit)
    assert len(questions) <= 2, (
        f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
    )

View file

@ -1,201 +0,0 @@
import os
import pytest
import pathlib
from typing import List
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    """DocumentChunk-shaped payload schema with `text` as the indexed field.

    NOTE(review): per the hunk header this whole file is being deleted in
    this change; documented only for the record.
    """

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default None contradicts List[Entity]; Optional would be accurate.
    contains: List[Entity] = None
    metadata: dict = {"index_fields": ["text"]}
class TestChunksRetriever:
@pytest.mark.asyncio
async def test_chunk_context_simple(self):
    """Seed three chunks and verify ChunksRetriever ranks 'Mike Broski' first for 'Mike'."""
    # Per-test storage roots keep this run isolated from other tests.
    system_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
    )
    cognee.config.system_root_directory(system_directory_path)
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
    )
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()
    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    chunk1 = DocumentChunk(
        text="Steve Rodger",
        chunk_size=2,
        chunk_index=0,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk2 = DocumentChunk(
        text="Mike Broski",
        chunk_size=2,
        chunk_index=1,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    chunk3 = DocumentChunk(
        text="Christina Mayer",
        chunk_size=2,
        chunk_index=2,
        cut_type="sentence_end",
        is_part_of=document,
        contains=[],
    )
    entities = [chunk1, chunk2, chunk3]
    await add_data_points(entities)
    retriever = ChunksRetriever()
    # get_context returns ranked chunk payloads; the assertion pins the top hit.
    context = await retriever.get_context("Mike")
    assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_chunk_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
await add_data_points(entities)
retriever = ChunksRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
retriever = ChunksRetriever()
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
vector_engine = get_vector_engine()
await vector_engine.create_collection(
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
)
context = await retriever.get_context("Christina Mayer")
assert len(context) == 0, "Found chunks when none should exist"

View file

@ -1,154 +0,0 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.context_global_variables import session_user
import importlib
def create_mock_cache_engine(qa_history=None):
    """Build an AsyncMock cache engine preloaded with a Q&A history list.

    ``get_latest_qa`` resolves to *qa_history* (empty list when None);
    ``add_qa`` resolves to None.
    """
    history = [] if qa_history is None else qa_history
    engine = AsyncMock()
    engine.get_latest_qa = AsyncMock(return_value=history)
    engine.add_qa = AsyncMock(return_value=None)
    return engine
def create_mock_user():
    """Return a MagicMock standing in for an authenticated user with a fixed id."""
    mock = MagicMock()
    mock.id = "test-user-id-123"
    return mock
class TestConversationHistoryUtils:
    """Unit tests for the session-cache conversation-history helpers."""

    @staticmethod
    def _cache_engine_module():
        """Resolve the module that defines get_cache_engine, for safe patching."""
        return importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )

    @pytest.mark.asyncio
    async def test_get_conversation_history_returns_empty_when_no_history(self):
        """With no stored Q&A entries the helper returns an empty string."""
        session_user.set(create_mock_user())
        empty_cache = create_mock_cache_engine([])
        with patch.object(
            self._cache_engine_module(), "get_cache_engine", return_value=empty_cache
        ):
            # Imported after patching so the helper sees the mocked engine.
            from cognee.modules.retrieval.utils.session_cache import get_conversation_history

            result = await get_conversation_history(session_id="test_session")
        assert result == ""

    @pytest.mark.asyncio
    async def test_get_conversation_history_formats_history_correctly(self):
        """Test get_conversation_history formats Q&A history with correct structure."""
        session_user.set(create_mock_user())
        qa_entry = {
            "time": "2024-01-15 10:30:45",
            "question": "What is AI?",
            "context": "AI is artificial intelligence",
            "answer": "AI stands for Artificial Intelligence",
        }
        cache = create_mock_cache_engine([qa_entry])
        with patch.object(
            self._cache_engine_module(), "get_cache_engine", return_value=cache
        ), patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as cache_config_cls:
            # Make every CacheConfig() report caching as enabled.
            cache_config_cls.return_value = MagicMock(caching=True)
            from cognee.modules.retrieval.utils.session_cache import (
                get_conversation_history,
            )

            result = await get_conversation_history(session_id="test_session")

        expected_fragments = [
            "Previous conversation:",
            "[2024-01-15 10:30:45]",
            "QUESTION: What is AI?",
            "CONTEXT: AI is artificial intelligence",
            "ANSWER: AI stands for Artificial Intelligence",
        ]
        for fragment in expected_fragments:
            assert fragment in result

    @pytest.mark.asyncio
    async def test_save_to_session_cache_saves_correctly(self):
        """Test save_conversation_history calls add_qa with correct parameters."""
        session_user.set(create_mock_user())
        cache = create_mock_cache_engine([])
        with patch.object(
            self._cache_engine_module(), "get_cache_engine", return_value=cache
        ), patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as cache_config_cls:
            cache_config_cls.return_value = MagicMock(caching=True)
            from cognee.modules.retrieval.utils.session_cache import (
                save_conversation_history,
            )

            result = await save_conversation_history(
                query="What is Python?",
                context_summary="Python is a programming language",
                answer="Python is a high-level programming language",
                session_id="my_session",
            )

        assert result is True
        cache.add_qa.assert_called_once()
        saved = cache.add_qa.call_args.kwargs
        assert saved["question"] == "What is Python?"
        assert saved["context"] == "Python is a programming language"
        assert saved["answer"] == "Python is a high-level programming language"
        assert saved["session_id"] == "my_session"

    @pytest.mark.asyncio
    async def test_save_to_session_cache_uses_default_session_when_none(self):
        """Test save_conversation_history uses 'default_session' when session_id is None."""
        session_user.set(create_mock_user())
        cache = create_mock_cache_engine([])
        with patch.object(
            self._cache_engine_module(), "get_cache_engine", return_value=cache
        ), patch(
            "cognee.modules.retrieval.utils.session_cache.CacheConfig"
        ) as cache_config_cls:
            cache_config_cls.return_value = MagicMock(caching=True)
            from cognee.modules.retrieval.utils.session_cache import (
                save_conversation_history,
            )

            result = await save_conversation_history(
                query="Test question",
                context_summary="Test context",
                answer="Test answer",
                session_id=None,
            )

        assert result is True
        assert cache.add_qa.call_args.kwargs["session_id"] == "default_session"

View file

@ -1,177 +0,0 @@
import os
import pytest
import pathlib
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
class TestGraphCompletionWithContextExtensionRetriever:
    """Integration tests for GraphCompletionContextExtensionRetriever.

    Fix: removed a leftover debug ``print(context)`` from the complex test.
    """

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_simple(self):
        """works_for edges are retrieved and completions are non-empty strings."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_extension_context_simple",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_graph_completion_extension_context_simple",
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionContextExtensionRetriever()
        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_complex(self):
        """A richer ownership graph still surfaces the works_for edges."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_extension_context_complex",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_graph_completion_extension_context_complex",
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionContextExtensionRetriever(top_k=20)
        context = await resolve_edges_to_text(
            await retriever.get_context("Who works at Figma and drives Tesla?")
        )
        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_get_graph_completion_extension_context_on_empty_graph(self):
        """On an empty graph the context is empty but completion still answers."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        retriever = GraphCompletionContextExtensionRetriever()
        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

View file

@ -1,170 +0,0 @@
import os
import pytest
import pathlib
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestGraphCompletionCoTRetriever:
    """Integration tests for GraphCompletionCotRetriever.

    Fix: removed a leftover debug ``print(context)`` from the complex test.
    """

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):
        """works_for edges are retrieved and completions are non-empty strings."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever()
        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_complex(self):
        """A richer ownership graph still surfaces the works_for edges."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_cot_context_complex",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever(top_k=20)
        context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_get_graph_completion_cot_context_on_empty_graph(self):
        """On an empty graph the context is empty but completion still answers."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        retriever = GraphCompletionCotRetriever()
        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

View file

@ -1,223 +0,0 @@
import os
import pytest
import pathlib
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
class TestGraphCompletionRetriever:
    """Integration tests for GraphCompletionRetriever context rendering.

    Fix: removed a leftover debug ``print(context)`` from the complex test.
    """

    @pytest.mark.asyncio
    async def test_graph_completion_context_simple(self):
        """Rendered context contains node headers, content blocks, and edges."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            description: str

        class Person(DataPoint):
            name: str
            description: str
            works_for: Company

        company1 = Company(name="Figma", description="Figma is a company")
        company2 = Company(name="Canva", description="Canvas is a company")
        person1 = Person(
            name="Steve Rodger",
            description="This is description about Steve Rodger",
            works_for=company1,
        )
        person2 = Person(
            name="Ike Loma", description="This is description about Ike Loma", works_for=company1
        )
        person3 = Person(
            name="Jason Statham",
            description="This is description about Jason Statham",
            works_for=company1,
        )
        person4 = Person(
            name="Mike Broski",
            description="This is description about Mike Broski",
            works_for=company2,
        )
        person5 = Person(
            name="Christina Mayer",
            description="This is description about Christina Mayer",
            works_for=company2,
        )
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionRetriever()
        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

        # Ensure the top-level sections are present
        assert "Nodes:" in context, "Missing 'Nodes:' section in context"
        assert "Connections:" in context, "Missing 'Connections:' section in context"

        # --- Nodes headers ---
        assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
        assert "Node: Figma" in context, "Missing node header for Figma"
        assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
        assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
        assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
        assert "Node: Canva" in context, "Missing node header for Canva"
        assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"

        # --- Node contents ---
        assert (
            "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
            in context
        ), "Description block for Steve Rodger altered"
        assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
            "Description block for Figma altered"
        )
        assert (
            "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
            in context
        ), "Description block for Ike Loma altered"
        assert (
            "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
            in context
        ), "Description block for Jason Statham altered"
        assert (
            "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
            in context
        ), "Description block for Mike Broski altered"
        assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
            "Description block for Canva altered"
        )
        assert (
            "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
            in context
        ), "Description block for Christina Mayer altered"

        # --- Connections ---
        assert "Steve Rodger --[works_for]--> Figma" in context, (
            "Connection Steve Rodger→Figma missing or changed"
        )
        assert "Ike Loma --[works_for]--> Figma" in context, (
            "Connection Ike Loma→Figma missing or changed"
        )
        assert "Jason Statham --[works_for]--> Figma" in context, (
            "Connection Jason Statham→Figma missing or changed"
        )
        assert "Mike Broski --[works_for]--> Canva" in context, (
            "Connection Mike Broski→Canva missing or changed"
        )
        assert "Christina Mayer --[works_for]--> Canva" in context, (
            "Connection Christina Mayer→Canva missing or changed"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_context_complex(self):
        """A richer ownership graph still surfaces the works_for edges."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
        entities = [company1, company2, person1, person2, person3, person4, person5]
        await add_data_points(entities)

        retriever = GraphCompletionRetriever(top_k=20)
        context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

    @pytest.mark.asyncio
    async def test_get_graph_completion_context_on_empty_graph(self):
        """An empty graph yields an empty context list."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        retriever = GraphCompletionRetriever()
        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

View file

@ -1,205 +0,0 @@
import os
from typing import List
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunkWithEntities(DataPoint):
    """Payload schema mirroring DocumentChunk's indexed fields.

    Used only to create an empty ``DocumentChunk_text`` vector collection
    without running full document ingestion.
    """

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): default is None but the annotation is List[Entity];
    # should likely be Optional[List[Entity]] — confirm against DataPoint usage.
    contains: List[Entity] = None
    # Tells the vector engine which field(s) to embed/index.
    metadata: dict = {"index_fields": ["text"]}
class TestRAGCompletionRetriever:
    """Integration tests for CompletionRetriever context construction."""

    @staticmethod
    async def _reset_environment(test_name: str) -> None:
        """Point cognee at per-test storage directories and wipe prior state."""
        base_dir = pathlib.Path(__file__).parent
        cognee.config.system_root_directory(
            os.path.join(base_dir, ".cognee_system/" + test_name)
        )
        cognee.config.data_root_directory(
            os.path.join(base_dir, ".data_storage/" + test_name)
        )
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

    @staticmethod
    def _build_chunks(document, texts):
        """Create one DocumentChunk per text, indexed in list order."""
        return [
            DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=position,
                cut_type="sentence_end",
                is_part_of=document,
                contains=[],
            )
            for position, chunk_text in enumerate(texts)
        ]

    @pytest.mark.asyncio
    async def test_rag_completion_context_simple(self):
        """The best-matching chunk text becomes the returned context."""
        await self._reset_environment("test_rag_completion_context_simple")
        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        await add_data_points(
            self._build_chunks(document, ["Steve Rodger", "Mike Broski", "Christina Mayer"])
        )

        retriever = CompletionRetriever()
        context = await retriever.get_context("Mike")
        assert context == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_rag_completion_context_complex(self):
        """Chunks spread across multiple documents are searchable together."""
        await self._reset_environment("test_rag_completion_context_complex")
        await setup()

        employees = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        cars = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        data_points = self._build_chunks(
            employees, ["Steve Rodger", "Mike Broski", "Christina Mayer"]
        ) + self._build_chunks(cars, ["Range Rover", "Hyundai", "Chrysler"])
        await add_data_points(data_points)

        # TODO: top_k doesn't affect the output, it should be fixed.
        retriever = CompletionRetriever(top_k=20)
        context = await retriever.get_context("Christina")
        assert context.startswith("Christina Mayer"), "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_get_rag_completion_context_on_empty_graph(self):
        """No data raises NoDataError; an empty collection yields empty context."""
        await self._reset_environment("test_get_rag_completion_context_on_empty_graph")

        retriever = CompletionRetriever()
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection(
            "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
        )
        context = await retriever.get_context("Christina Mayer")
        assert context == "", "Returned context should be empty on an empty graph"

View file

@ -1,159 +0,0 @@
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TestSummariesRetriever:
    """Integration tests for SummariesRetriever against real storage backends."""

    @pytest.mark.asyncio
    async def test_chunk_context(self):
        """Retrieving by a chunk's text returns the summary made from that chunk."""
        # Isolate system/data storage directories for this test run.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()
        # Two source documents: people (document1) and cars (document2).
        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        # Each chunk gets a short TextSummary; only the summaries are indexed below.
        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk1_summary = TextSummary(
            text="S.R.",
            made_from=chunk1,
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2_summary = TextSummary(
            text="M.B.",
            made_from=chunk2,
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3_summary = TextSummary(
            text="C.M.",
            made_from=chunk3,
        )
        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk4_summary = TextSummary(
            text="R.R.",
            made_from=chunk4,
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5_summary = TextSummary(
            text="H.Y.",
            made_from=chunk5,
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6_summary = TextSummary(
            text="C.H.",
            made_from=chunk6,
        )
        entities = [
            chunk1_summary,
            chunk2_summary,
            chunk3_summary,
            chunk4_summary,
            chunk5_summary,
            chunk6_summary,
        ]
        await add_data_points(entities)
        retriever = SummariesRetriever(top_k=20)
        # "Christina" should rank chunk3's summary ("C.M.") first.
        context = await retriever.get_context("Christina")
        assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        """NoDataError without the summary collection; empty list once it exists."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        retriever = SummariesRetriever()
        # Missing collection: retrieval must raise rather than return silently.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")
        vector_engine = get_vector_engine()
        await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
        # Collection exists but is empty: an empty result list is expected.
        context = await retriever.get_context("Christina Mayer")
        assert context == [], "Returned context should be empty on an empty graph"

View file

@ -1,224 +0,0 @@
from types import SimpleNamespace
import pytest
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
# Test TemporalRetriever initialization defaults and overrides
def test_init_defaults_and_overrides():
    """Default prompt paths and top_k apply when omitted; explicit kwargs win."""
    default_retriever = TemporalRetriever()
    assert default_retriever.top_k == 5
    assert default_retriever.user_prompt_path == "graph_context_for_question.txt"
    assert default_retriever.system_prompt_path == "answer_simple_question.txt"
    assert default_retriever.time_extraction_prompt_path == "extract_query_time.txt"

    custom_retriever = TemporalRetriever(
        top_k=3,
        user_prompt_path="u.txt",
        system_prompt_path="s.txt",
        time_extraction_prompt_path="t.txt",
    )
    assert custom_retriever.top_k == 3
    assert custom_retriever.user_prompt_path == "u.txt"
    assert custom_retriever.system_prompt_path == "s.txt"
    assert custom_retriever.time_extraction_prompt_path == "t.txt"
# Test descriptions_to_string with basic and empty results
def test_descriptions_to_string_basic_and_empty():
    """Non-empty descriptions are stripped and joined; empty input yields ''."""
    retriever = TemporalRetriever()
    mixed_results = [
        {"description": " First "},
        {"nope": "no description"},
        {"description": "Second"},
        {"description": ""},
        {"description": " Third line "},
    ]
    rendered = retriever.descriptions_to_string(mixed_results)
    expected = "First\n#####################\nSecond\n#####################\nThird line"
    assert rendered == expected
    # No results at all renders as an empty string.
    assert retriever.descriptions_to_string([]) == ""
# Test filter_top_k_events sorts and limits correctly
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
    """Events are ranked by ascending vector score and truncated to top_k."""
    retriever = TemporalRetriever(top_k=2)
    candidate_events = [
        {
            "events": [
                {"id": "e1", "description": "E1"},
                {"id": "e2", "description": "E2"},
                {"id": "e3", "description": "E3 - not in vector results"},
            ]
        }
    ]
    vector_hits = [
        SimpleNamespace(payload={"id": "e2"}, score=0.10),
        SimpleNamespace(payload={"id": "e1"}, score=0.20),
    ]
    ranked = await retriever.filter_top_k_events(candidate_events, vector_hits)
    # Lower score ranks first; e3 (no vector hit) is dropped by top_k=2.
    assert [event["id"] for event in ranked] == ["e2", "e1"]
    assert all("score" in event for event in ranked)
    assert ranked[0]["score"] == 0.10
    assert ranked[1]["score"] == 0.20
# Test filter_top_k_events handles unknown ids as infinite scores
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
    """Events absent from vector results score as infinity and fall out of top_k."""
    retriever = TemporalRetriever(top_k=2)
    candidate_events = [
        {
            "events": [
                {"id": "known1", "description": "Known 1"},
                {"id": "unknown", "description": "Unknown"},
                {"id": "known2", "description": "Known 2"},
            ]
        }
    ]
    vector_hits = [
        SimpleNamespace(payload={"id": "known2"}, score=0.05),
        SimpleNamespace(payload={"id": "known1"}, score=0.50),
    ]
    ranked = await retriever.filter_top_k_events(candidate_events, vector_hits)
    assert [event["id"] for event in ranked] == ["known2", "known1"]
    # The "unknown" event's infinite score keeps it out of the returned top-k.
    assert all(event["score"] != float("inf") for event in ranked)
# Test descriptions_to_string with unicode and newlines
def test_descriptions_to_string_unicode_and_newlines():
    """Embedded newlines inside descriptions survive the join."""
    retriever = TemporalRetriever()
    sample = [
        {"description": "Line A\nwith newline"},
        {"description": "This is a description"},
    ]
    joined = retriever.descriptions_to_string(sample)
    assert "Line A\nwith newline" in joined
    assert "This is a description" in joined
    # Exactly one separator sits between the two descriptions.
    assert joined.count("#####################") == 1
# Test filter_top_k_events when top_k is larger than available events
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
    """When top_k exceeds the event count, every event is returned."""
    retriever = TemporalRetriever(top_k=10)
    candidate_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
    vector_hits = [
        SimpleNamespace(payload={"id": "a"}, score=0.1),
        SimpleNamespace(payload={"id": "b"}, score=0.2),
    ]
    ranked = await retriever.filter_top_k_events(candidate_events, vector_hits)
    assert [event["id"] for event in ranked] == ["a", "b"]
# Test filter_top_k_events when scored_results is empty
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
    """With no vector scores, events keep original order and infinite scores."""
    retriever = TemporalRetriever(top_k=2)
    candidate_events = [{"events": [{"id": "x"}, {"id": "y"}]}]
    ranked = await retriever.filter_top_k_events(candidate_events, [])
    assert [event["id"] for event in ranked] == ["x", "y"]
    assert all(event["score"] == float("inf") for event in ranked)
# Test filter_top_k_events error handling for missing structure
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
    """A group missing the 'events' key is malformed and must raise."""
    retriever = TemporalRetriever(top_k=2)
    with pytest.raises((KeyError, TypeError)):
        await retriever.filter_top_k_events([{}], [])
class _FakeRetriever(TemporalRetriever):
    """TemporalRetriever double: stubs LLM/graph/vector calls and records them.

    Only filter_top_k_events and descriptions_to_string run the real
    implementation; everything else is keyword-driven canned data.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ordered log of (method_name, arg) tuples for call-order assertions.
        self._calls = []

    async def extract_time_from_query(self, query: str):
        # The query text selects which time bounds are "extracted".
        if "both" in query:
            return "2024-01-01", "2024-12-31"
        if "from_only" in query:
            return "2024-01-01", None
        if "to_only" in query:
            return None, "2024-12-31"
        return None, None

    async def get_triplets(self, query: str):
        self._calls.append(("get_triplets", query))
        return [{"s": "a", "p": "b", "o": "c"}]

    async def resolve_edges_to_text(self, triplets):
        self._calls.append(("resolve_edges_to_text", len(triplets)))
        return "edges->text"

    async def _fake_graph_collect_ids(self, **kwargs):
        # Stands in for the graph query that collects event ids in a time window.
        return ["e1", "e2"]

    async def _fake_graph_collect_events(self, ids):
        # Returns one extra event ("e3") that has no vector hit below.
        return [
            {
                "events": [
                    {"id": "e1", "description": "E1"},
                    {"id": "e2", "description": "E2"},
                    {"id": "e3", "description": "E3"},
                ]
            }
        ]

    async def _fake_vector_embed(self, texts):
        assert isinstance(texts, list) and texts
        return [[0.0, 1.0, 2.0]]

    async def _fake_vector_search(self, **kwargs):
        # e2 scores better (lower) than e1; e3 is absent on purpose.
        return [
            SimpleNamespace(payload={"id": "e2"}, score=0.05),
            SimpleNamespace(payload={"id": "e1"}, score=0.10),
        ]

    async def get_context(self, query: str):
        """Mirror of the real flow: triplet fallback without time, ranking with it."""
        time_from, time_to = await self.extract_time_from_query(query)
        if not (time_from or time_to):
            # No temporal bounds: fall back to plain triplet retrieval.
            triplets = await self.get_triplets(query)
            return await self.resolve_edges_to_text(triplets)
        ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
        relevant_events = await self._fake_graph_collect_events(ids)
        _ = await self._fake_vector_embed([query])
        vector_search_results = await self._fake_vector_search(
            collection_name="Event_name", query_vector=[0.0], limit=0
        )
        # Real ranking logic under test.
        top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
        return self.descriptions_to_string(top_k_events)
# Test get_context fallback to triplets when no time is extracted
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
    """Without extracted time bounds, get_context goes through the triplet path."""
    retriever = _FakeRetriever(top_k=2)
    context = await retriever.get_context("no_time")
    assert context == "edges->text"
    call_order = [name for name, _ in retriever._calls]
    assert call_order[:2] == ["get_triplets", "resolve_edges_to_text"]
# Test get_context when time is extracted and vector ranking is applied
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
    """With time bounds, vector scores order the events and top_k trims them."""
    retriever = _FakeRetriever(top_k=2)
    context = await retriever.get_context("both time")
    # e2 has the best (lowest) score, so its description leads the output.
    assert context.startswith("E2")
    assert "#####################" in context
    assert "E1" in context
    assert "E3" not in context

View file

@ -1,608 +0,0 @@
import pytest
from unittest.mock import AsyncMock, patch
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
get_memory_fragment,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
class MockScoredResult:
    """Mock class for vector search results."""

    def __init__(self, id, score, payload=None):
        # Mirrors the minimal attribute surface of a real scored result.
        self.id = id
        self.score = score
        self.payload = payload or {}  # default to an empty payload dict
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
    """Test that empty query raises ValueError."""
    # Validation happens before any engine work, so no mocks are needed.
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query="")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
    """Test that None query raises ValueError."""
    # None is rejected by the same non-empty-string validation as "".
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query=None)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
    """Test that negative top_k raises ValueError."""
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=-1)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
    """Test that zero top_k raises ValueError."""
    # Zero is excluded as well: top_k must be strictly positive.
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=0)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
    """Test that wide_search_limit is applied for global search (node_name=None).

    Fix: the original looped over call_args_list without first checking that
    search was called at all, so an implementation that never searched would
    pass vacuously. assert_called() closes that hole.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,  # Global search
            wide_search_top_k=75,
        )
    # Guard against a vacuous pass when no search calls were made.
    mock_vector_engine.search.assert_called()
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 75
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
    """Test that wide_search_limit is None for filtered search (node_name provided).

    Fix: added assert_called() so the per-call limit check cannot pass
    vacuously if search is never invoked.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=["Node1"],
            wide_search_top_k=50,
        )
    # Guard against a vacuous pass when no search calls were made.
    mock_vector_engine.search.assert_called()
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
    """Test that wide_search_top_k defaults to 100.

    Fix: added assert_called() so the per-call limit check cannot pass
    vacuously if search is never invoked.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    # Guard against a vacuous pass when no search calls were made.
    mock_vector_engine.search.assert_called()
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 100
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
    """Test that default collections are used when none provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test")
    # One search per default collection, asserted in declaration order.
    expected_collections = [
        "Entity_name",
        "TextSummary_text",
        "EntityType_name",
        "DocumentChunk_text",
        "EdgeType_relationship_name",
    ]
    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert call_collections == expected_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
    """Test that custom collections are used when provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    custom_collections = ["CustomCol1", "CustomCol2"]
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=custom_collections)
    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    # The edge collection is always appended on top of caller-provided ones.
    assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
    """Test that EdgeType_relationship_name is always searched even when not in collections."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])
    # Deliberately omit the edge collection from the explicit list.
    collections_without_edge = ["Entity_name", "TextSummary_text"]
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=collections_without_edge)
    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert "EdgeType_relationship_name" in call_collections
    assert set(call_collections) == set(collections_without_edge) | {
        "EdgeType_relationship_name"
    }
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
    """Test that empty list is returned when all collections return no results."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[])
    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    )
    with engine_patch:
        results = await brute_force_triplet_search(query="test")
    assert results == []
# Tests for query embedding
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
    """Test that query is embedded before searching."""
    query_text = "test query"
    expected_vector = [0.1, 0.2, 0.3]
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    # embed_text returns a batch; the single vector inside is what search should get.
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
    mock_vector_engine.search = AsyncMock(return_value=[])
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query=query_text)
    # The query is embedded exactly once, as a one-element batch.
    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["query_vector"] == expected_vector
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
    """Test that node IDs are extracted from search results for global search."""
    scored_results = [
        MockScoredResult("node1", 0.95),
        MockScoredResult("node2", 0.87),
        MockScoredResult("node3", 0.92),
    ]
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=scored_results)
    # Fragment stub: the mapping/importance methods just need to be awaitable.
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    # All scored-result ids must be forwarded as the filter set.
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
    """Test that provided memory fragment is reused instead of creating new one."""
    provided_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(
            query="test",
            memory_fragment=provided_fragment,
            node_name=["node"],
        )
    # The factory must be bypassed entirely when a fragment is injected.
    mock_get_fragment.assert_not_called()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
    """Test that memory fragment is created when not provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=["node"])
    # No memory_fragment kwarg was passed, so the factory must run once.
    mock_get_fragment.assert_called_once()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
    """Test that custom top_k is passed to importance calculation."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        custom_top_k = 15
        await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
    # top_k flows through unchanged as the k kwarg of the importance ranking.
    mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """Test that get_memory_fragment returns empty graph when entity not found."""
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )
    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=graph_engine,
    )
    with engine_patch:
        fragment = await get_memory_fragment()
    # The error is swallowed and an empty CogneeGraph comes back instead.
    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
    """Test that get_memory_fragment returns empty graph on generic error."""
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=graph_engine,
    )
    with engine_patch:
        fragment = await get_memory_fragment()
    # Even unexpected failures degrade to an empty graph rather than raising.
    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
    """Test that duplicate node IDs across collections are deduplicated."""

    def search_side_effect(*args, **kwargs):
        # "node1" intentionally appears in two collections.
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            ]
        elif collection_name == "TextSummary_text":
            return [
                MockScoredResult("node1", 0.90),
                MockScoredResult("node3", 0.92),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
    # Length check proves the duplicate "node1" was collapsed to one entry.
    assert len(call_kwargs["relevant_ids_to_filter"]) == 3
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Test that EdgeType_relationship_name collection is excluded from ID extraction."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "EdgeType_relationship_name":
            # Edge hits must not contribute node ids.
            return [MockScoredResult("edge1", 0.88)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )
    # Only the entity hit survives; "edge1" is filtered out.
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Test that nodes without ID attribute are skipped."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            # The middle result lacks .id and must be ignored, not crash.
            return [
                MockScoredResult("node1", 0.95),
                ScoredResultNoId(0.90),
                MockScoredResult("node2", 0.87),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            # Return a tuple (not a list) to exercise generic iteration.
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """Test ID extraction with mixed empty and non-empty collections."""

    def search_side_effect(*args, **kwargs):
        # TextSummary_text is deliberately empty between two non-empty hits.
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "TextSummary_text":
            return []
        elif collection_name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}

View file

@ -1,83 +0,0 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
    """Create a mock vector engine."""
    mock_engine = AsyncMock()
    # Collection exists by default; individual tests override as needed.
    mock_engine.has_collection = AsyncMock(return_value=True)
    mock_engine.search = AsyncMock()
    return mock_engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Test successful retrieval of triplet context."""
    first_hit = MagicMock()
    first_hit.payload = {"text": "Alice knows Bob"}
    second_hit = MagicMock()
    second_hit.payload = {"text": "Bob works at Tech Corp"}
    mock_vector_engine.search.return_value = [first_hit, second_hit]
    retriever = TripletRetriever(top_k=5)
    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        context = await retriever.get_context("test query")
    # Triplet texts are joined with newlines, preserving result order.
    assert context == "Alice knows Bob\nBob works at Tech Corp"
    mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
    """Test that NoDataError is raised when Triplet_text collection doesn't exist."""
    mock_vector_engine.has_collection.return_value = False
    retriever = TripletRetriever()
    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    # The error message should point users at create_triplet_embeddings.
    with engine_patch, pytest.raises(NoDataError, match="create_triplet_embeddings"):
        await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """Test that empty string is returned when no triplets are found."""
    mock_vector_engine.search.return_value = []
    retriever = TripletRetriever()
    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        context = await retriever.get_context("test query")
    assert context == ""
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """Test that CollectionNotFoundError is converted to NoDataError."""
    mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
    retriever = TripletRetriever()
    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    # The infrastructure error must surface as the retrieval-level NoDataError.
    with engine_patch, pytest.raises(NoDataError, match="No data found"):
        await retriever.get_context("test query")