chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881)
<!-- .github/pull_request_template.md -->
## Description
This PR restructures/adds integration and unit tests for the retrieval
module.
- Old integration tests were updated and moved under unit tests, and
fixtures were added
- Added missing unit tests for all core retrieval business logic
- Covered 100% of the core retrievers with tests
- Minor changes (dead code removed, typo fixed)
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Changes**
* TripletRetriever now returns up to 5 results by default (was 1),
providing richer context.
* **Tests**
* Reorganized test coverage: many unit tests removed and replaced with
comprehensive integration tests across retrieval components (graph,
chunks, RAG, summaries, temporal, triplets, structured output).
* **Chores**
* Simplified triplet formatting logic and removed debug output.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
78028b819f
commit
4e8845c117
23 changed files with 1888 additions and 2303 deletions
|
|
@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
|
|||
"""Initialize retriever with optional custom prompt paths."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 1
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
|
|
|
|||
|
|
@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)
|
|||
|
||||
|
||||
def format_triplets(edges):
|
||||
print("\n\n\n")
|
||||
|
||||
def filter_attributes(obj, attributes):
|
||||
"""Helper function to filter out non-None properties, including nested dicts."""
|
||||
result = {}
|
||||
for attr in attributes:
|
||||
value = getattr(obj, attr, None)
|
||||
if value is not None:
|
||||
# If the value is a dict, extract relevant keys from it
|
||||
if isinstance(value, dict):
|
||||
nested_values = {
|
||||
k: v for k, v in value.items() if k in attributes and v is not None
|
||||
}
|
||||
result[attr] = nested_values
|
||||
else:
|
||||
result[attr] = value
|
||||
return result
|
||||
|
||||
triplets = []
|
||||
for edge in edges:
|
||||
node1 = edge.node1
|
||||
|
|
|
|||
252
cognee/tests/integration/retrieval/test_chunks_retriever.py
Normal file
252
cognee/tests/integration/retrieval/test_chunks_retriever.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
import os
import pathlib
from typing import List, Optional

import pytest
import pytest_asyncio

import cognee
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.low_level import setup
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import Document, TextDocument
from cognee.modules.engine.models import Entity
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.tasks.storage import add_data_points
|
||||
|
||||
class DocumentChunkWithEntities(DataPoint):
    """Document chunk payload schema that can also carry extracted entities.

    Used as the payload schema when creating the "DocumentChunk_text" vector
    collection in tests that need an existing-but-empty collection.
    """

    text: str  # chunk content; indexed for vector search (see metadata)
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # Entities mentioned in the chunk; None when extraction has not run.
    # Annotated Optional so the None default matches the declared type.
    contains: Optional[List[Entity]] = None

    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Set up a clean test environment with simple chunks.

    Creates one document with three two-word chunks, stores them through
    add_data_points, then prunes all data/system state again on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    def make_chunk(text: str, index: int) -> DocumentChunk:
        # All test chunks share the same shape; only text and position vary.
        return DocumentChunk(
            text=text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

    entities = [
        make_chunk(text, index)
        for index, text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ]

    await add_data_points(entities)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Best-effort cleanup; a failed teardown must not mask test results.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Set up a clean test environment with complex chunks.

    Creates two documents (employees and cars) with three two-word chunks
    each, stores them via add_data_points, and prunes all state on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def make_document(name: str) -> TextDocument:
        # The two test documents differ only by name.
        return TextDocument(
            name=name,
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

    def make_chunk(document: TextDocument, text: str, index: int) -> DocumentChunk:
        # All chunks share the same shape; only document, text and position vary.
        return DocumentChunk(
            text=text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

    document1 = make_document("Employee List")
    document2 = make_document("Car List")

    entities = [
        make_chunk(document1, text, index)
        for index, text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ] + [
        make_chunk(document2, text, index)
        for index, text in enumerate(["Range Rover", "Hyundai", "Chrysler"])
    ]

    await add_data_points(entities)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Best-effort cleanup; a failed teardown must not mask test results.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without chunks."""
    root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(root / ".cognee_system/test_chunks_retriever_context_empty")
    )
    cognee.config.data_root_directory(
        str(root / ".data_storage/test_chunks_retriever_context_empty")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: verify ChunksRetriever can retrieve multiple chunks."""
    retriever = ChunksRetriever()

    context = await retriever.get_context("Steve")

    assert isinstance(context, list), "Context should be a list"
    assert len(context) > 0, "Context should not be empty"

    retrieved_texts = [chunk["text"] for chunk in context]
    assert "Steve Rodger" in retrieved_texts, "Failed to get Steve Rodger chunk"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever respects top_k parameter."""
    limited_retriever = ChunksRetriever(top_k=2)

    result = await limited_retriever.get_context("Employee")

    assert isinstance(result, list), "Context should be a list"
    assert len(result) <= 2, "Should respect top_k limit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: verify ChunksRetriever can retrieve chunk context (complex)."""
    retriever = ChunksRetriever(top_k=20)

    context = await retriever.get_context("Christina")

    top_chunk = context[0]
    assert top_chunk["text"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify ChunksRetriever handles empty graph correctly."""
    retriever = ChunksRetriever()

    # With no collection at all, retrieval must raise.
    with pytest.raises(NoDataError):
        await retriever.get_context("Christina Mayer")

    # Create an empty collection so the lookup succeeds but finds nothing.
    engine = get_vector_engine()
    await engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunkWithEntities)

    context = await retriever.get_context("Christina Mayer")
    assert len(context) == 0, "Found chunks when none should exist"
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
from typing import Optional, Union
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data.

    Builds two companies and five described people connected to them via
    works_for edges, stores everything, and prunes all state on teardown.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    company1 = Company(name="Figma", description="Figma is a company")
    company2 = Company(name="Canva", description="Canvas is a company")

    def make_person(name: str, employer: Company) -> Person:
        # Every test person uses the same description template.
        return Person(
            name=name,
            description=f"This is description about {name}",
            works_for=employer,
        )

    people = [
        make_person("Steve Rodger", company1),
        make_person("Ike Loma", company1),
        make_person("Jason Statham", company1),
        make_person("Mike Broski", company2),
        make_person("Christina Mayer", company2),
    ]

    await add_data_points([company1, company2, *people])

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Best-effort cleanup; a failed teardown must not mask test results.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    mike_rodger = Person(name="Mike Rodger", works_for=figma)
    mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

    ike_loma = Person(name="Ike Loma", works_for=figma)
    ike_loma.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]

    # Jason owns nothing — exercises the None default for `owns`.
    jason_statham = Person(name="Jason Statham", works_for=figma)

    mike_broski = Person(name="Mike Broski", works_for=canva)
    mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

    christina_mayer = Person(name="Christina Mayer", works_for=canva)
    christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

    await add_data_points(
        [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
    )

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(root / ".cognee_system/test_get_graph_completion_context_on_empty_graph")
    )
    cognee.config.data_root_directory(
        str(root / ".data_storage/test_get_graph_completion_context_on_empty_graph")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionRetriever can retrieve context (simple)."""
    retriever = GraphCompletionRetriever()

    context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

    # Top-level sections must both be present.
    for section in ("Nodes:", "Connections:"):
        assert section in context, f"Missing '{section}' section in context"

    # Per-node description blocks, keyed by node header name.
    descriptions = {
        "Steve Rodger": "This is description about Steve Rodger",
        "Figma": "Figma is a company",
        "Ike Loma": "This is description about Ike Loma",
        "Jason Statham": "This is description about Jason Statham",
        "Mike Broski": "This is description about Mike Broski",
        "Canva": "Canvas is a company",
        "Christina Mayer": "This is description about Christina Mayer",
    }
    for node, description in descriptions.items():
        assert f"Node: {node}" in context, f"Missing node header for {node}"
        content_block = f"__node_content_start__\n{description}\n__node_content_end__"
        assert content_block in context, f"Description block for {node} altered"

    # works_for edges rendered in the connections section.
    expected_edges = [
        ("Steve Rodger", "Figma"),
        ("Ike Loma", "Figma"),
        ("Jason Statham", "Figma"),
        ("Mike Broski", "Canva"),
        ("Christina Mayer", "Canva"),
    ]
    for person, company in expected_edges:
        assert f"{person} --[works_for]--> {company}" in context, (
            f"Connection {person}→{company} missing or changed"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionRetriever can retrieve context (complex)."""
    retriever = GraphCompletionRetriever(top_k=20)

    context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))

    for employee in ("Mike Rodger", "Ike Loma", "Jason Statham"):
        assert f"{employee} --[works_for]--> Figma" in context, f"Failed to get {employee}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever handles empty graph correctly."""
    retriever = GraphCompletionRetriever()

    result = await retriever.get_context("Who works at Figma?")

    assert result == [], "Context should be empty on an empty graph"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_get_triplets_empty(setup_test_environment_empty):
    """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph."""
    retriever = GraphCompletionRetriever()

    result = await retriever.get_triplets("Who works at Figma?")

    assert isinstance(result, list), "Triplets should be a list"
    assert len(result) == 0, "Should return empty list on empty graph"
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
from typing import Optional, Union
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data."""
    root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(root / ".cognee_system/test_graph_completion_extension_context_simple")
    )
    cognee.config.data_root_directory(
        str(root / ".data_storage/test_graph_completion_extension_context_simple")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    staff = [
        Person(name=person_name, works_for=employer)
        for person_name, employer in (
            ("Steve Rodger", figma),
            ("Ike Loma", figma),
            ("Jason Statham", figma),
            ("Mike Broski", canva),
            ("Christina Mayer", canva),
        )
    ]

    await add_data_points([figma, canva, *staff])

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data."""
    root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(root / ".cognee_system/test_graph_completion_extension_context_complex")
    )
    cognee.config.data_root_directory(
        str(root / ".data_storage/test_graph_completion_extension_context_complex")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    mike_rodger = Person(name="Mike Rodger", works_for=figma)
    mike_rodger.owns = [Car(brand="Toyota", model="Camry", year=2020)]

    ike_loma = Person(name="Ike Loma", works_for=figma)
    ike_loma.owns = [
        Car(brand="Tesla", model="Model S", year=2021),
        Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
    ]

    # Jason owns nothing — exercises the None default for `owns`.
    jason_statham = Person(name="Jason Statham", works_for=figma)

    mike_broski = Person(name="Mike Broski", works_for=canva)
    mike_broski.owns = [Car(brand="Ford", model="Mustang", year=1978)]

    christina_mayer = Person(name="Christina Mayer", works_for=canva)
    christina_mayer.owns = [Car(brand="Honda", model="Civic", year=2023)]

    await add_data_points(
        [figma, canva, mike_rodger, ike_loma, jason_statham, mike_broski, christina_mayer]
    )

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    root = pathlib.Path(__file__).parent.parent.parent.parent
    cognee.config.system_root_directory(
        str(root / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph")
    )
    cognee.config.data_root_directory(
        str(root / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph")
    )

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Cleanup is best-effort only.
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(setup_test_environment_simple):
    """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple)."""
    retriever = GraphCompletionContextExtensionRetriever()

    context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

    for employee in ("Mike Broski", "Christina Mayer"):
        assert f"{employee} --[works_for]--> Canva" in context, f"Failed to get {employee}"

    answer = await retriever.get_completion("Who works at Canva?")

    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    assert all(isinstance(item, str) and item.strip() for item in answer), (
        "Answer must contain only non-empty strings"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(setup_test_environment_complex):
    """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex)."""
    retriever = GraphCompletionContextExtensionRetriever(top_k=20)

    context = await resolve_edges_to_text(
        await retriever.get_context("Who works at Figma and drives Tesla?")
    )

    for employee in ("Mike Rodger", "Ike Loma", "Jason Statham"):
        assert f"{employee} --[works_for]--> Figma" in context, f"Failed to get {employee}"

    answer = await retriever.get_completion("Who works at Figma?")

    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    assert all(isinstance(item, str) and item.strip() for item in answer), (
        "Answer must contain only non-empty strings"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly."""
    retriever = GraphCompletionContextExtensionRetriever()

    result = await retriever.get_context("Who works at Figma?")
    assert result == [], "Context should be empty on an empty graph"

    answer = await retriever.get_completion("Who works at Figma?")

    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    assert all(isinstance(item, str) and item.strip() for item in answer), (
        "Answer must contain only non-empty strings"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty):
    """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph."""
    retriever = GraphCompletionContextExtensionRetriever()

    result = await retriever.get_triplets("Who works at Figma?")

    assert isinstance(result, list), "Triplets should be a list"
    assert len(result) == 0, "Should return empty list on empty graph"
|
||||
|
|
@ -0,0 +1,218 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
from typing import Optional, Union
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_simple():
    """Set up a clean test environment with simple graph data (two companies, five people)."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_simple"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    # Three Figma employees and two Canva employees.
    people = [
        Person(name=person_name, works_for=employer)
        for person_name, employer in [
            ("Steve Rodger", figma),
            ("Ike Loma", figma),
            ("Jason Statham", figma),
            ("Mike Broski", canva),
            ("Christina Mayer", canva),
        ]
    ]

    await add_data_points([figma, canva, *people])

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_complex():
    """Set up a clean test environment with complex graph data (people owning cars/homes)."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_graph_completion_cot_context_complex"
    )
    data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    class Car(DataPoint):
        brand: str
        model: str
        year: int

    class Location(DataPoint):
        country: str
        city: str

    class Home(DataPoint):
        location: Location
        rooms: int
        sqm: int

    class Person(DataPoint):
        name: str
        works_for: Company
        owns: Optional[list[Union[Car, Home]]] = None

    figma = Company(name="Figma")
    canva = Company(name="Canva")

    # Ownership is passed through the constructor instead of assigned afterwards.
    people = [
        Person(
            name="Mike Rodger",
            works_for=figma,
            owns=[Car(brand="Toyota", model="Camry", year=2020)],
        ),
        Person(
            name="Ike Loma",
            works_for=figma,
            owns=[
                Car(brand="Tesla", model="Model S", year=2021),
                Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
            ],
        ),
        Person(name="Jason Statham", works_for=figma),
        Person(
            name="Mike Broski",
            works_for=canva,
            owns=[Car(brand="Ford", model="Mustang", year=1978)],
        ),
        Person(
            name="Christina Mayer",
            works_for=canva,
            owns=[Car(brand="Honda", model="Civic", year=2023)],
        ),
    ]

    await add_data_points([figma, canva, *people])

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without graph data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    suffix = "test_get_graph_completion_cot_context_on_empty_graph"
    cognee.config.system_root_directory(str(base_dir / f".cognee_system/{suffix}"))
    cognee.config.data_root_directory(str(base_dir / f".data_storage/{suffix}"))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(setup_test_environment_simple):
    """Integration test: GraphCompletionCotRetriever retrieves context from a simple graph."""
    query = "Who works at Canva?"
    retriever = GraphCompletionCotRetriever()

    edges = await retriever.get_context(query)
    context = await resolve_edges_to_text(edges)

    # Both Canva employees must appear as works_for edges in the rendered context.
    for edge_text, who in [
        ("Mike Broski --[works_for]--> Canva", "Mike Broski"),
        ("Christina Mayer --[works_for]--> Canva", "Christina Mayer"),
    ]:
        assert edge_text in context, f"Failed to get {who}"

    answer = await retriever.get_completion(query)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    for item in answer:
        assert isinstance(item, str) and item.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(setup_test_environment_complex):
    """Integration test: GraphCompletionCotRetriever retrieves context from a complex graph."""
    query = "Who works at Figma?"
    # A larger top_k so all expected edges fit in the retrieved context.
    retriever = GraphCompletionCotRetriever(top_k=20)

    edges = await retriever.get_context(query)
    context = await resolve_edges_to_text(edges)

    # All three Figma employees must appear as works_for edges.
    for edge_text, who in [
        ("Mike Rodger --[works_for]--> Figma", "Mike Rodger"),
        ("Ike Loma --[works_for]--> Figma", "Ike Loma"),
        ("Jason Statham --[works_for]--> Figma", "Jason Statham"),
    ]:
        assert edge_text in context, f"Failed to get {who}"

    answer = await retriever.get_completion(query)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    for item in answer:
        assert isinstance(item, str) and item.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: GraphCompletionCotRetriever on an empty graph."""
    query = "Who works at Figma?"
    retriever = GraphCompletionCotRetriever()

    # An empty graph must yield an empty context rather than raising.
    context = await retriever.get_context(query)
    assert context == [], "Context should be empty on an empty graph"

    # A completion is still produced and must be a list of non-empty strings.
    answer = await retriever.get_completion(query)
    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
    for item in answer:
        assert isinstance(item, str) and item.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty):
    """Integration test: GraphCompletionCotRetriever.get_triplets on an empty graph."""
    triplets = await GraphCompletionCotRetriever().get_triplets("Who works at Figma?")

    assert isinstance(triplets, list), "Triplets should be a list"
    assert len(triplets) == 0, "Should return empty list on empty graph"
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
import os
import pathlib
from typing import List, Optional

import pytest
import pytest_asyncio

import cognee
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.low_level import setup
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import Document, TextDocument
from cognee.modules.engine.models import Entity
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.tasks.storage import add_data_points
|
||||
|
||||
|
||||
class DocumentChunkWithEntities(DataPoint):
    """Payload schema mirroring a document chunk, used to create an empty
    ``DocumentChunk_text`` vector collection in the empty-graph test.

    Mirrors the fields of ``DocumentChunk`` so the collection schema matches
    what ``CompletionRetriever`` expects to query.
    """

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # Fixed: default is None, so the annotation must be Optional —
    # a bare `List[Entity] = None` contradicts the declared type.
    contains: Optional[List[Entity]] = None

    # Tells the vector engine which fields to embed/index.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Set up a clean test environment with three simple document chunks."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # One chunk per name, indexed in order.
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        for index, chunk_text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ]

    await add_data_points(chunks)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Set up a clean test environment with chunks spread over two documents."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # Three name chunks per document, indexed 0..2 within each document.
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=parent_document,
            contains=[],
        )
        for parent_document, texts in [
            (document1, ["Steve Rodger", "Mike Broski", "Christina Mayer"]),
            (document2, ["Range Rover", "Hyundai", "Chrysler"]),
        ]
        for index, chunk_text in enumerate(texts)
    ]

    await add_data_points(chunks)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without chunks.

    NOTE: unlike the populated fixtures, this one deliberately does NOT call
    ``setup()``, so no collections exist — the test exercises NoDataError.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    suffix = "test_get_rag_completion_context_on_empty_graph"
    cognee.config.system_root_directory(str(base_dir / f".cognee_system/{suffix}"))
    cognee.config.data_root_directory(str(base_dir / f".data_storage/{suffix}"))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever retrieves a chunk as string context."""
    context = await CompletionRetriever().get_context("Mike")

    assert isinstance(context, str), "Context should be a string"
    assert "Mike Broski" in context, "Failed to get Mike Broski"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: CompletionRetriever retrieves context when several chunks exist."""
    context = await CompletionRetriever().get_context("Steve")

    assert isinstance(context, str), "Context should be a string"
    assert "Steve Rodger" in context, "Failed to get Steve Rodger"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: CompletionRetriever picks the best-matching chunk first."""
    # TODO: top_k doesn't affect the output, it should be fixed.
    retriever = CompletionRetriever(top_k=20)

    context = await retriever.get_context("Christina")

    # The closest match must lead the returned context.
    assert context.startswith("Christina Mayer"), "Failed to get Christina Mayer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: CompletionRetriever behavior with no data at all."""
    retriever = CompletionRetriever()

    # Without the chunk collection the retriever must signal missing data.
    with pytest.raises(NoDataError):
        await retriever.get_context("Christina Mayer")

    # Once an empty collection exists, the context is simply empty.
    await get_vector_engine().create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )

    context = await retriever.get_context("Christina Mayer")
    assert context == "", "Returned context should be empty on an empty graph"
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import cognee
|
||||
import pathlib
|
||||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from pydantic import BaseModel
|
||||
from cognee.low_level import setup, DataPoint
|
||||
|
|
@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion():
|
|||
_assert_structured_answer(structured_answer)
|
||||
|
||||
|
||||
class TestStructuredOutputCompletion:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_structured_completion(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
@pytest_asyncio.fixture
async def setup_test_environment():
    """Set up a clean test environment with graph and document data."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion")
    data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str

    class Person(DataPoint):
        name: str
        works_for: Company
        works_since: int

    figma = Company(name="Figma")
    steve = Person(name="Steve Rodger", works_for=figma, works_since=2015)
    await add_data_points([figma, steve])

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    # One chunk per name, indexed in order.
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        for index, chunk_text in enumerate(["Steve Rodger", "Mike Broski", "Christina Mayer"])
    ]
    await add_data_points(chunks)

    # One typed entity so entity-completion retrievers have data to hit.
    person_type = EntityType(name="Person", description="A human individual")
    einstein = Entity(name="Albert Einstein", is_a=person_type, description="A famous physicist")
    await add_data_points([einstein])

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()
    except Exception:
        pass
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: Company
|
||||
works_since: int
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
|
||||
|
||||
entities = [company1, person1]
|
||||
await add_data_points(entities)
|
||||
|
||||
document = TextDocument(
|
||||
name="Steve Rodger's career",
|
||||
raw_data_location="somewhere",
|
||||
external_metadata="",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
chunk1 = DocumentChunk(
|
||||
text="Steve Rodger",
|
||||
chunk_size=2,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document,
|
||||
contains=[],
|
||||
)
|
||||
chunk2 = DocumentChunk(
|
||||
text="Mike Broski",
|
||||
chunk_size=2,
|
||||
chunk_index=1,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document,
|
||||
contains=[],
|
||||
)
|
||||
chunk3 = DocumentChunk(
|
||||
text="Christina Mayer",
|
||||
chunk_size=2,
|
||||
chunk_index=2,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=document,
|
||||
contains=[],
|
||||
)
|
||||
|
||||
entities = [chunk1, chunk2, chunk3]
|
||||
await add_data_points(entities)
|
||||
|
||||
entity_type = EntityType(name="Person", description="A human individual")
|
||||
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
||||
|
||||
entities = [entity]
|
||||
await add_data_points(entities)
|
||||
|
||||
await _test_get_structured_graph_completion_cot()
|
||||
await _test_get_structured_graph_completion()
|
||||
await _test_get_structured_graph_completion_temporal()
|
||||
await _test_get_structured_graph_completion_rag()
|
||||
await _test_get_structured_graph_completion_context_extension()
|
||||
await _test_get_structured_entity_completion()
|
||||
@pytest.mark.asyncio
async def test_get_structured_completion(setup_test_environment):
    """Integration test: verify structured output completion for all retrievers."""
    # Run each retriever's structured-output check sequentially against the
    # shared fixture data.
    for run_check in (
        _test_get_structured_graph_completion_cot,
        _test_get_structured_graph_completion,
        _test_get_structured_graph_completion_temporal,
        _test_get_structured_graph_completion_rag,
        _test_get_structured_graph_completion_context_extension,
        _test_get_structured_entity_completion,
    ):
        await run_check()
|
||||
184
cognee/tests/integration/retrieval/test_summaries_retriever.py
Normal file
184
cognee/tests/integration/retrieval/test_summaries_retriever.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.tasks.summarization.models import TextSummary
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_summaries():
    """Set up a clean test environment with one TextSummary per document chunk."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context")
    data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document1 = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    document2 = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # Each chunk is abbreviated to its initials; only the summaries are added
    # as data points — chunks travel along through made_from.
    summaries = []
    for parent_document, pairs in [
        (document1, [("Steve Rodger", "S.R."), ("Mike Broski", "M.B."), ("Christina Mayer", "C.M.")]),
        (document2, [("Range Rover", "R.R."), ("Hyundai", "H.Y."), ("Chrysler", "C.H.")]),
    ]:
        for index, (chunk_text, abbreviation) in enumerate(pairs):
            chunk = DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=index,
                cut_type="sentence_end",
                is_part_of=parent_document,
                contains=[],
            )
            summaries.append(TextSummary(text=abbreviation, made_from=chunk))

    await add_data_points(summaries)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without summaries.

    NOTE: deliberately does NOT call ``setup()`` so no collections exist —
    the empty-graph test exercises NoDataError first.
    """
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    suffix = "test_summaries_retriever_context_empty"
    cognee.config.system_root_directory(str(base_dir / f".cognee_system/{suffix}"))
    cognee.config.data_root_directory(str(base_dir / f".data_storage/{suffix}"))

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_summaries_retriever_context(setup_test_environment_with_summaries):
    """Integration test: SummariesRetriever retrieves the closest summary first."""
    context = await SummariesRetriever(top_k=20).get_context("Christina")

    assert isinstance(context, list), "Context should be a list"
    assert len(context) > 0, "Context should not be empty"
    assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: SummariesRetriever behavior with no data at all."""
    retriever = SummariesRetriever()

    # Without the summary collection the retriever must signal missing data.
    with pytest.raises(NoDataError):
        await retriever.get_context("Christina Mayer")

    # Once an empty collection exists, the context is simply empty.
    await get_vector_engine().create_collection("TextSummary_text", payload_schema=TextSummary)

    context = await retriever.get_context("Christina Mayer")
    assert context == [], "Returned context should be empty on an empty graph"
|
||||
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
306
cognee/tests/integration/retrieval/test_temporal_retriever.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.engine.models.Event import Event
|
||||
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||
from cognee.modules.engine.models.Interval import Interval
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_events():
    """Set up a clean test environment with temporal events across 2021."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def first_of_month(epoch_seconds: int, year: int, month: int) -> Timestamp:
        # All fixture timestamps fall on the first of the month at midnight UTC.
        return Timestamp(
            time_at=epoch_seconds,
            year=year,
            month=month,
            day=1,
            hour=0,
            minute=0,
            second=0,
            timestamp_str=f"{year}-{month:02d}-01T00:00:00",
        )

    timestamp1 = first_of_month(1609459200, 2021, 1)   # 2021-01-01
    timestamp2 = first_of_month(1612137600, 2021, 2)   # 2021-02-01
    timestamp3 = first_of_month(1614556800, 2021, 3)   # 2021-03-01
    timestamp4 = first_of_month(1625097600, 2021, 7)   # 2021-07-01
    timestamp5 = first_of_month(1633046400, 2021, 10)  # 2021-10-01

    # The team meeting spans February through March.
    q1_interval = Interval(time_from=timestamp2, time_to=timestamp3)

    events = [
        Event(
            name="Project Alpha Launch",
            description="Launched Project Alpha at the beginning of 2021",
            at=timestamp1,
            location="San Francisco",
        ),
        Event(
            name="Team Meeting",
            description="Monthly team meeting discussing Q1 goals",
            during=q1_interval,
            location="New York",
        ),
        Event(
            name="Product Release",
            description="Released new product features in July",
            at=timestamp4,
            location="Remote",
        ),
        Event(
            name="Company Retreat",
            description="Annual company retreat in October",
            at=timestamp5,
            location="Lake Tahoe",
        ),
    ]

    await add_data_points(events)

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_with_graph_data():
    """Set up a clean test environment with graph data (for fallback to triplets)."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    class Company(DataPoint):
        name: str
        description: str

    class Person(DataPoint):
        name: str
        description: str
        works_for: Company

    figma = Company(name="Figma", description="Figma is a company")
    steve = Person(
        name="Steve Rodger",
        description="This is description about Steve Rodger",
        works_for=figma,
    )

    await add_data_points([figma, steve])

    yield

    # Best-effort cleanup; a failure here must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without data."""
    # Dedicated directories keep this empty-graph scenario isolated.
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_empty")
    data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    # Wipe everything and initialize the databases, but add no data points.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    yield

    # Best-effort cleanup; a teardown failure must not mask the test result.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever can retrieve events within time range."""
    temporal_retriever = TemporalRetriever(top_k=5)

    result = await temporal_retriever.get_context("What happened in January 2021?")

    assert isinstance(result, str), "Context should be a string"
    assert len(result) > 0, "Context should not be empty"
    mentions_launch = ("Project Alpha" in result) or ("Launch" in result)
    assert mentions_launch, "Should retrieve Project Alpha Launch event from January 2021"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever can retrieve events at specific time."""
    temporal_retriever = TemporalRetriever(top_k=5)

    result = await temporal_retriever.get_context("What happened in July 2021?")

    assert isinstance(result, str), "Context should be a string"
    assert len(result) > 0, "Context should not be empty"
    mentions_release = ("Product Release" in result) or ("July" in result)
    assert mentions_release, "Should retrieve Product Release event from July 2021"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_context_fallback_to_triplets(
    setup_test_environment_with_graph_data,
):
    """Integration test: verify TemporalRetriever falls back to triplets when no time extracted."""
    temporal_retriever = TemporalRetriever(top_k=5)

    result = await temporal_retriever.get_context("Who works at Figma?")

    assert isinstance(result, str), "Context should be a string"
    assert len(result) > 0, "Context should not be empty"
    hit = any(needle in result for needle in ("Steve", "Figma"))
    assert hit, "Should retrieve graph data via triplet search fallback"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty):
    """Integration test: verify TemporalRetriever handles empty graph correctly.

    On an empty graph the retriever must not raise; it must still return a
    string (possibly empty).
    """
    retriever = TemporalRetriever()

    context = await retriever.get_context("What happened?")

    # The previous `len(context) >= 0` check was a tautology (len() can never
    # be negative) and asserted nothing; the isinstance check already covers
    # the "string, possibly empty" contract.
    assert isinstance(context, str), "Context should be a string"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever can generate completions."""
    temporal_retriever = TemporalRetriever()

    answers = await temporal_retriever.get_completion("What happened in January 2021?")

    assert isinstance(answers, list), "Completion should be a list"
    assert len(answers) > 0, "Completion should not be empty"
    valid = all(isinstance(entry, str) and entry.strip() for entry in answers)
    assert valid, "Completion items should be non-empty strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data):
    """Integration test: verify TemporalRetriever get_completion works with triplet fallback."""
    temporal_retriever = TemporalRetriever()

    answers = await temporal_retriever.get_completion("Who works at Figma?")

    assert isinstance(answers, list), "Completion should be a list"
    assert len(answers) > 0, "Completion should not be empty"
    valid = all(isinstance(entry, str) and entry.strip() for entry in answers)
    assert valid, "Completion items should be non-empty strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever respects top_k parameter."""
    temporal_retriever = TemporalRetriever(top_k=2)

    result = await temporal_retriever.get_context("What happened in 2021?")

    assert isinstance(result, str), "Context should be a string"
    # With at most 2 events rendered there can be at most 1 separator between them.
    assert result.count("#####################") <= 1, "Should respect top_k limit of 2 events"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_retriever_multiple_events(setup_test_environment_with_events):
    """Integration test: verify TemporalRetriever can retrieve multiple events."""
    temporal_retriever = TemporalRetriever(top_k=10)

    result = await temporal_retriever.get_context("What events occurred in 2021?")

    assert isinstance(result, str), "Context should be a string"
    assert len(result) > 0, "Context should not be empty"

    event_names = ("Project Alpha", "Team Meeting", "Product Release", "Company Retreat")
    assert any(name in result for name in event_names), (
        "Should retrieve at least one event from 2021"
    )
|
||||
|
|
@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip
|
|||
context = await retriever.get_context("Alice")
|
||||
|
||||
assert "Alice knows Bob" in context, "Failed to get Alice triplet"
|
||||
assert isinstance(context, str), "Context should be a string"
|
||||
assert len(context) > 0, "Context should not be empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets):
    """Integration test: verify TripletRetriever can retrieve multiple triplets."""
    triplet_retriever = TripletRetriever(top_k=5)

    result = await triplet_retriever.get_context("Bob")

    bob_facts = ("Alice knows Bob", "Bob works at Tech Corp")
    assert any(fact in result for fact in bob_facts), "Failed to get Bob-related triplets"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets):
    """Integration test: verify TripletRetriever respects top_k parameter."""
    triplet_retriever = TripletRetriever(top_k=1)

    result = await triplet_retriever.get_context("Alice")

    assert isinstance(result, str), "Context should be a string"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_triplet_retriever_context_empty(setup_test_environment_empty):
    """Integration test: verify TripletRetriever handles empty graph correctly."""
    # NOTE(review): the fixture already ran setup(); this second call looks
    # redundant — confirm it is intentional (e.g. an idempotency check).
    await setup()

    retriever = TripletRetriever()

    # With no data ingested, retrieval is expected to raise NoDataError.
    with pytest.raises(NoDataError):
        await retriever.get_context("Alice")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
|||
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
||||
"""
|
||||
|
||||
# Minimal one-item corpus matching the HotpotQA / TwoWikiMultiHop raw schema,
# used so adapter tests never download the real datasets.
MOCK_HOTPOT_CORPUS = [
    {
        "_id": "1",
        "question": "Next to which country is Germany located?",
        "answer": "Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        "level": "easy",
        "type": "comparison",
        "context": [
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        "supporting_facts": [["Netherlands", 0]],
    }
]
|
||||
|
||||
|
||||
ADAPTER_CLASSES = [
|
||||
HotpotQAAdapter,
|
||||
|
|
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
|||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
||||
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
||||
else:
|
||||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
|
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
|||
):
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
else:
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,38 @@ import pytest
|
|||
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
||||
|
||||
# Benchmarks exercised by the parametrized corpus-builder tests below.
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
||||
|
||||
# Minimal one-item corpus matching the HotpotQA / TwoWikiMultiHop raw schema,
# used so the corpus-builder tests stay offline.
MOCK_HOTPOT_CORPUS = [
    {
        "_id": "1",
        "question": "Next to which country is Germany located?",
        "answer": "Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        "level": "easy",
        "type": "comparison",
        "context": [
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        "supporting_facts": [["Netherlands", 0]],
    }
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("benchmark", benchmark_options)
|
||||
def test_corpus_builder_load_corpus(benchmark):
|
||||
limit = 2
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
else:
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
|
||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||
assert len(questions) <= 2, (
|
||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||
|
|
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
|||
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
    """Build a corpus per benchmark and check the question limit is honored.

    `run_cognee` is mocked so no real pipeline work happens.  For benchmarks
    backed by remote datasets (HotPotQA / TwoWikiMultiHop) the raw-corpus
    fetch is patched with an in-memory fixture so the test stays offline.
    """
    limit = 2
    # The previous version also built the corpus once *before* installing the
    # `_get_raw_corpus` patch, which triggered a real dataset download; that
    # redundant unpatched call is removed.
    if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
        with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
            corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
            questions = await corpus_builder.build_corpus(limit=limit)
    else:
        corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
        questions = await corpus_builder.build_corpus(limit=limit)

    assert len(questions) <= 2, (
        f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
    )
|
||||
|
|
|
|||
|
|
@ -1,201 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import List
|
||||
import cognee
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
class DocumentChunkWithEntities(DataPoint):
    # Payload schema used when (re)creating the "DocumentChunk_text" vector
    # collection in the empty-graph test below.
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): annotated List[Entity] but defaults to None — probably
    # should be Optional[List[Entity]]; confirm DataPoint tolerates this.
    contains: List[Entity] = None

    # "index_fields" tells the vector store which attribute to embed.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class TestChunksRetriever:
    """Integration tests for ChunksRetriever against a seeded vector store."""

    @pytest.mark.asyncio
    async def test_chunk_context_simple(self):
        """Retrieve the matching chunk from a single three-chunk document."""
        # Isolate storage for this test case.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Clean slate, then initialize databases.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3]

        await add_data_points(entities)

        retriever = ChunksRetriever()

        context = await retriever.get_context("Mike")

        # Best match for "Mike" must be the "Mike Broski" chunk.
        assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_chunk_context_complex(self):
        """Retrieve the right chunk when two documents' chunks are mixed."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        # Two unrelated documents: people and cars.
        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )

        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]

        await add_data_points(entities)

        # top_k larger than the corpus: ordering, not truncation, is under test.
        retriever = ChunksRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        """Missing collection raises NoDataError; empty collection yields []."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Note: setup() is intentionally NOT called, so no collections exist.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = ChunksRetriever()

        # Without the collection, retrieval must fail loudly.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        # Create the (empty) collection; retrieval should now succeed with
        # zero results instead of raising.
        vector_engine = get_vector_engine()
        await vector_engine.create_collection(
            "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
        )

        context = await retriever.get_context("Christina Mayer")
        assert len(context) == 0, "Found chunks when none should exist"
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from cognee.context_global_variables import session_user
|
||||
import importlib
|
||||
|
||||
|
||||
def create_mock_cache_engine(qa_history=None):
    """Build an AsyncMock cache engine whose get_latest_qa returns *qa_history*."""
    history = [] if qa_history is None else qa_history
    engine = AsyncMock()
    engine.get_latest_qa = AsyncMock(return_value=history)
    engine.add_qa = AsyncMock(return_value=None)
    return engine
|
||||
|
||||
|
||||
def create_mock_user():
    """Return a MagicMock standing in for a user, with a fixed id."""
    user = MagicMock()
    user.id = "test-user-id-123"
    return user
|
||||
|
||||
|
||||
class TestConversationHistoryUtils:
    """Unit tests for session-cache conversation-history helpers.

    The cache engine is replaced with an AsyncMock, and the target functions
    are imported *inside* the patch context so they bind the patched engine.
    """

    @pytest.mark.asyncio
    async def test_get_conversation_history_returns_empty_when_no_history(self):
        """With no stored Q&A, the formatted history is an empty string."""
        user = create_mock_user()
        session_user.set(user)
        mock_cache = create_mock_cache_engine([])

        cache_module = importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )

        with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
            from cognee.modules.retrieval.utils.session_cache import get_conversation_history

            result = await get_conversation_history(session_id="test_session")

        assert result == ""

    @pytest.mark.asyncio
    async def test_get_conversation_history_formats_history_correctly(self):
        """Test get_conversation_history formats Q&A history with correct structure."""
        user = create_mock_user()
        session_user.set(user)

        mock_history = [
            {
                "time": "2024-01-15 10:30:45",
                "question": "What is AI?",
                "context": "AI is artificial intelligence",
                "answer": "AI stands for Artificial Intelligence",
            }
        ]
        mock_cache = create_mock_cache_engine(mock_history)

        # Import the real module to patch safely
        cache_module = importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )

        with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
            # Caching must be enabled for history formatting to run.
            with patch(
                "cognee.modules.retrieval.utils.session_cache.CacheConfig"
            ) as MockCacheConfig:
                mock_config = MagicMock()
                mock_config.caching = True
                MockCacheConfig.return_value = mock_config

                from cognee.modules.retrieval.utils.session_cache import (
                    get_conversation_history,
                )

                result = await get_conversation_history(session_id="test_session")

        # The formatted transcript carries timestamp, question, context, answer.
        assert "Previous conversation:" in result
        assert "[2024-01-15 10:30:45]" in result
        assert "QUESTION: What is AI?" in result
        assert "CONTEXT: AI is artificial intelligence" in result
        assert "ANSWER: AI stands for Artificial Intelligence" in result

    @pytest.mark.asyncio
    async def test_save_to_session_cache_saves_correctly(self):
        """Test save_conversation_history calls add_qa with correct parameters."""
        user = create_mock_user()
        session_user.set(user)

        mock_cache = create_mock_cache_engine([])

        cache_module = importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )

        with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
            with patch(
                "cognee.modules.retrieval.utils.session_cache.CacheConfig"
            ) as MockCacheConfig:
                mock_config = MagicMock()
                mock_config.caching = True
                MockCacheConfig.return_value = mock_config

                from cognee.modules.retrieval.utils.session_cache import (
                    save_conversation_history,
                )

                result = await save_conversation_history(
                    query="What is Python?",
                    context_summary="Python is a programming language",
                    answer="Python is a high-level programming language",
                    session_id="my_session",
                )

        assert result is True
        mock_cache.add_qa.assert_called_once()

        # Every field must be forwarded to the cache verbatim.
        call_kwargs = mock_cache.add_qa.call_args.kwargs
        assert call_kwargs["question"] == "What is Python?"
        assert call_kwargs["context"] == "Python is a programming language"
        assert call_kwargs["answer"] == "Python is a high-level programming language"
        assert call_kwargs["session_id"] == "my_session"

    @pytest.mark.asyncio
    async def test_save_to_session_cache_uses_default_session_when_none(self):
        """Test save_conversation_history uses 'default_session' when session_id is None."""
        user = create_mock_user()
        session_user.set(user)

        mock_cache = create_mock_cache_engine([])

        cache_module = importlib.import_module(
            "cognee.infrastructure.databases.cache.get_cache_engine"
        )

        with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
            with patch(
                "cognee.modules.retrieval.utils.session_cache.CacheConfig"
            ) as MockCacheConfig:
                mock_config = MagicMock()
                mock_config.caching = True
                MockCacheConfig.return_value = mock_config

                from cognee.modules.retrieval.utils.session_cache import (
                    save_conversation_history,
                )

                result = await save_conversation_history(
                    query="Test question",
                    context_summary="Test context",
                    answer="Test answer",
                    session_id=None,
                )

        assert result is True
        call_kwargs = mock_cache.add_qa.call_args.kwargs
        assert call_kwargs["session_id"] == "default_session"
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphCompletionWithContextExtensionRetriever:
    """Integration tests for GraphCompletionContextExtensionRetriever."""

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_simple(self):
        """Retrieve works_for edges from a small two-company graph."""
        # Isolate storage for this test case.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_extension_context_simple",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_graph_completion_extension_context_simple",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionContextExtensionRetriever()

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

        # Both Canva employees should appear as works_for edges.
        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")

        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_complex(self):
        """Retrieve edges from a richer graph with ownership relations."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_extension_context_complex",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_graph_completion_extension_context_complex",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")

        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        person3 = Person(name="Jason Statham", works_for=company1)

        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionContextExtensionRetriever(top_k=20)

        context = await resolve_edges_to_text(
            await retriever.get_context("Who works at Figma and drives Tesla?")
        )

        # (leftover debug print(context) removed)
        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")

        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_get_graph_completion_extension_context_on_empty_graph(self):
        """On an empty graph the context is [] and completion still returns strings."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = GraphCompletionContextExtensionRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")

        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionCoTRetriever:
    """Integration tests for GraphCompletionCotRetriever context retrieval and completion."""

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):
        """A small people/companies graph yields works_for edges and a string-list completion."""
        # Per-test system/data directories keep this test isolated from others.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Reset data and system state before building the graph.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        # Minimal schema: people linked to companies via works_for.
        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever()

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")

        # Completions are a list of non-empty strings.
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_complex(self):
        """A richer graph (cars, homes, locations) still surfaces all Figma employees."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_graph_completion_cot_context_complex",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")

        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        person3 = Person(name="Jason Statham", works_for=company1)

        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        # Larger top_k so all relevant edges fit in the retrieved context.
        retriever = GraphCompletionCotRetriever(top_k=20)

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))

        print(context)

        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")

        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_get_graph_completion_cot_context_on_empty_graph(self):
        """On an empty graph the retriever yields no context but still answers."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = GraphCompletionCotRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")

        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
    """Integration tests for GraphCompletionRetriever context rendering and retrieval."""

    @pytest.mark.asyncio
    async def test_graph_completion_context_simple(self):
        """Checks the exact rendered context: sections, node headers, content blocks, connections."""
        # Per-test system/data directories keep this test isolated from others.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Reset data and system state before building the graph.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            description: str

        class Person(DataPoint):
            name: str
            description: str
            works_for: Company

        company1 = Company(name="Figma", description="Figma is a company")
        company2 = Company(name="Canva", description="Canvas is a company")
        person1 = Person(
            name="Steve Rodger",
            description="This is description about Steve Rodger",
            works_for=company1,
        )
        person2 = Person(
            name="Ike Loma", description="This is description about Ike Loma", works_for=company1
        )
        person3 = Person(
            name="Jason Statham",
            description="This is description about Jason Statham",
            works_for=company1,
        )
        person4 = Person(
            name="Mike Broski",
            description="This is description about Mike Broski",
            works_for=company2,
        )
        person5 = Person(
            name="Christina Mayer",
            description="This is description about Christina Mayer",
            works_for=company2,
        )

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionRetriever()

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

        # Ensure the top-level sections are present
        assert "Nodes:" in context, "Missing 'Nodes:' section in context"
        assert "Connections:" in context, "Missing 'Connections:' section in context"

        # --- Nodes headers ---
        assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
        assert "Node: Figma" in context, "Missing node header for Figma"
        assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
        assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
        assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
        assert "Node: Canva" in context, "Missing node header for Canva"
        assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"

        # --- Node contents ---
        # Each description must be wrapped verbatim in the content delimiter markers.
        assert (
            "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
            in context
        ), "Description block for Steve Rodger altered"
        assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
            "Description block for Figma altered"
        )
        assert (
            "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
            in context
        ), "Description block for Ike Loma altered"
        assert (
            "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
            in context
        ), "Description block for Jason Statham altered"
        assert (
            "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
            in context
        ), "Description block for Mike Broski altered"
        assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
            "Description block for Canva altered"
        )
        assert (
            "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
            in context
        ), "Description block for Christina Mayer altered"

        # --- Connections ---
        assert "Steve Rodger --[works_for]--> Figma" in context, (
            "Connection Steve Rodger→Figma missing or changed"
        )
        assert "Ike Loma --[works_for]--> Figma" in context, (
            "Connection Ike Loma→Figma missing or changed"
        )
        assert "Jason Statham --[works_for]--> Figma" in context, (
            "Connection Jason Statham→Figma missing or changed"
        )
        assert "Mike Broski --[works_for]--> Canva" in context, (
            "Connection Mike Broski→Canva missing or changed"
        )
        assert "Christina Mayer --[works_for]--> Canva" in context, (
            "Connection Christina Mayer→Canva missing or changed"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_context_complex(self):
        """A richer graph (cars, homes, locations) still surfaces all Figma employees."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")

        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        person3 = Person(name="Jason Statham", works_for=company1)

        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        # Larger top_k so all relevant edges fit in the retrieved context.
        retriever = GraphCompletionRetriever(top_k=20)

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))

        print(context)

        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

    @pytest.mark.asyncio
    async def test_get_graph_completion_context_on_empty_graph(self):
        """On an empty graph get_context returns an empty list."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_graph_completion_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_graph_completion_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = GraphCompletionRetriever()

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"
|
||||
|
|
@ -1,205 +0,0 @@
|
|||
import os
|
||||
from typing import List
|
||||
import pytest
|
||||
import pathlib
|
||||
import cognee
|
||||
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
class DocumentChunkWithEntities(DataPoint):
    """Payload schema mirroring DocumentChunk with an entity list.

    Used as the ``payload_schema`` when creating the ``DocumentChunk_text``
    vector collection in the empty-graph test below.
    """

    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): annotated List[Entity] but defaulted to None — presumably
    # meant to be optional; confirm DataPoint accepts the None default.
    contains: List[Entity] = None

    # The vector index is built over the chunk text.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class TestRAGCompletionRetriever:
    """Integration tests for CompletionRetriever (plain RAG over document chunks)."""

    @pytest.mark.asyncio
    async def test_rag_completion_context_simple(self):
        """A single-document index returns the chunk text matching the query."""
        # Per-test system/data directories keep this test isolated from others.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Reset data and system state before indexing the chunks.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3]

        await add_data_points(entities)

        retriever = CompletionRetriever()

        context = await retriever.get_context("Mike")

        assert context == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_rag_completion_context_complex(self):
        """Chunks from two documents are indexed together; the best match ranks first."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )

        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]

        await add_data_points(entities)

        # TODO: top_k doesn't affect the output, it should be fixed.
        retriever = CompletionRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        # Only the leading characters are compared — the context apparently
        # concatenates several chunks with the best match first (inferred; verify).
        assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_get_rag_completion_context_on_empty_graph(self):
        """Without data a NoDataError is raised; with an empty collection, context is empty."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".cognee_system/test_get_rag_completion_context_on_empty_graph",
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent,
            ".data_storage/test_get_rag_completion_context_on_empty_graph",
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = CompletionRetriever()

        # No collection exists yet, so retrieval must fail loudly.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection(
            "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
        )

        # With an empty (but existing) collection, retrieval succeeds with no content.
        context = await retriever.get_context("Christina Mayer")
        assert context == "", "Returned context should be empty on an empty graph"
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
|
||||
import cognee
|
||||
from cognee.low_level import setup
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.tasks.summarization.models import TextSummary
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
|
||||
|
||||
class TestSummariesRetriever:
    """Integration tests for SummariesRetriever over TextSummary data points."""

    @pytest.mark.asyncio
    async def test_chunk_context(self):
        """Indexes one summary per chunk and retrieves the summary matching the query."""
        # Per-test system/data directories keep this test isolated from others.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Reset data and system state before indexing the summaries.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        # Each chunk gets an initials-style TextSummary; only the summaries are indexed.
        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk1_summary = TextSummary(
            text="S.R.",
            made_from=chunk1,
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2_summary = TextSummary(
            text="M.B.",
            made_from=chunk2,
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3_summary = TextSummary(
            text="C.M.",
            made_from=chunk3,
        )
        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk4_summary = TextSummary(
            text="R.R.",
            made_from=chunk4,
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5_summary = TextSummary(
            text="H.Y.",
            made_from=chunk5,
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6_summary = TextSummary(
            text="C.H.",
            made_from=chunk6,
        )

        entities = [
            chunk1_summary,
            chunk2_summary,
            chunk3_summary,
            chunk4_summary,
            chunk5_summary,
            chunk6_summary,
        ]

        await add_data_points(entities)

        retriever = SummariesRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        # Best match comes first; querying a chunk's source text matches its summary.
        assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        """Without data a NoDataError is raised; with an empty collection, context is empty."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = SummariesRetriever()

        # No collection exists yet, so retrieval must fail loudly.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)

        # With an empty (but existing) collection, retrieval succeeds with no content.
        context = await retriever.get_context("Christina Mayer")
        assert context == [], "Returned context should be empty on an empty graph"
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
|
||||
|
||||
# TemporalRetriever falls back to its documented defaults and honours
# constructor overrides for every configurable attribute.
def test_init_defaults_and_overrides():
    defaults = TemporalRetriever()
    assert defaults.top_k == 5
    assert defaults.user_prompt_path == "graph_context_for_question.txt"
    assert defaults.system_prompt_path == "answer_simple_question.txt"
    assert defaults.time_extraction_prompt_path == "extract_query_time.txt"

    overridden = TemporalRetriever(
        top_k=3,
        user_prompt_path="u.txt",
        system_prompt_path="s.txt",
        time_extraction_prompt_path="t.txt",
    )
    assert overridden.top_k == 3
    assert overridden.user_prompt_path == "u.txt"
    assert overridden.system_prompt_path == "s.txt"
    assert overridden.time_extraction_prompt_path == "t.txt"
|
||||
|
||||
|
||||
# descriptions_to_string keeps only entries with a non-empty "description",
# strips surrounding whitespace, and joins them with a separator line.
def test_descriptions_to_string_basic_and_empty():
    retriever = TemporalRetriever()

    payloads = [
        {"description": " First "},
        {"nope": "no description"},
        {"description": "Second"},
        {"description": ""},
        {"description": " Third line "},
    ]

    rendered = retriever.descriptions_to_string(payloads)
    assert rendered == "First\n#####################\nSecond\n#####################\nThird line"

    # An empty input renders as an empty string, not a lone separator.
    assert retriever.descriptions_to_string([]) == ""
|
||||
|
||||
|
||||
# filter_top_k_events orders events by ascending vector score, attaches the
# score to each event, and truncates the result to top_k.
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
    retriever = TemporalRetriever(top_k=2)

    candidate_events = [
        {
            "events": [
                {"id": "e1", "description": "E1"},
                {"id": "e2", "description": "E2"},
                {"id": "e3", "description": "E3 - not in vector results"},
            ]
        }
    ]

    vector_hits = [
        SimpleNamespace(payload={"id": "e2"}, score=0.10),
        SimpleNamespace(payload={"id": "e1"}, score=0.20),
    ]

    selected = await retriever.filter_top_k_events(candidate_events, vector_hits)

    # Lower score ranks first; the unscored e3 is cut by top_k.
    assert [event["id"] for event in selected] == ["e2", "e1"]
    assert all("score" in event for event in selected)
    assert selected[0]["score"] == 0.10
    assert selected[1]["score"] == 0.20
|
||||
|
||||
|
||||
# Events absent from the scored results rank last (treated as an infinite
# score) and therefore never displace scored events from the top_k cut.
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
    retriever = TemporalRetriever(top_k=2)

    candidate_events = [
        {
            "events": [
                {"id": "known1", "description": "Known 1"},
                {"id": "unknown", "description": "Unknown"},
                {"id": "known2", "description": "Known 2"},
            ]
        }
    ]

    vector_hits = [
        SimpleNamespace(payload={"id": "known2"}, score=0.05),
        SimpleNamespace(payload={"id": "known1"}, score=0.50),
    ]

    selected = await retriever.filter_top_k_events(candidate_events, vector_hits)
    assert [event["id"] for event in selected] == ["known2", "known1"]
    assert all(event["score"] != float("inf") for event in selected)
|
||||
|
||||
|
||||
# Embedded newlines in a description survive rendering, and a two-item input
# produces exactly one separator line.
def test_descriptions_to_string_unicode_and_newlines():
    retriever = TemporalRetriever()
    payloads = [
        {"description": "Line A\nwith newline"},
        {"description": "This is a description"},
    ]
    rendered = retriever.descriptions_to_string(payloads)
    assert "Line A\nwith newline" in rendered
    assert "This is a description" in rendered
    assert rendered.count("#####################") == 1
|
||||
|
||||
|
||||
# Requesting more events than exist returns every available event instead of
# failing or padding.
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
    retriever = TemporalRetriever(top_k=10)
    candidate_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
    vector_hits = [
        SimpleNamespace(payload={"id": "a"}, score=0.1),
        SimpleNamespace(payload={"id": "b"}, score=0.2),
    ]
    selected = await retriever.filter_top_k_events(candidate_events, vector_hits)
    assert [event["id"] for event in selected] == ["a", "b"]
|
||||
|
||||
|
||||
# With no vector scores at all, every event falls back to an infinite score
# but is still returned (up to top_k), preserving input order.
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
    retriever = TemporalRetriever(top_k=2)
    candidate_events = [{"events": [{"id": "x"}, {"id": "y"}]}]
    selected = await retriever.filter_top_k_events(candidate_events, [])
    assert [event["id"] for event in selected] == ["x", "y"]
    assert all(event["score"] == float("inf") for event in selected)
|
||||
|
||||
|
||||
# A group dict without an "events" key is malformed input and should surface
# as KeyError/TypeError rather than being silently skipped.
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
    retriever = TemporalRetriever(top_k=2)
    with pytest.raises((KeyError, TypeError)):
        await retriever.filter_top_k_events([{}], [])
|
||||
|
||||
|
||||
class _FakeRetriever(TemporalRetriever):
|
||||
def __init__(self, *args, **kwargs):
    """Initialize like TemporalRetriever but also record calls to the fake hooks."""
    super().__init__(*args, **kwargs)
    # (method_name, arg) tuples appended by the overridden methods below.
    self._calls = []
|
||||
|
||||
async def extract_time_from_query(self, query: str):
|
||||
if "both" in query:
|
||||
return "2024-01-01", "2024-12-31"
|
||||
if "from_only" in query:
|
||||
return "2024-01-01", None
|
||||
if "to_only" in query:
|
||||
return None, "2024-12-31"
|
||||
return None, None
|
||||
|
||||
async def get_triplets(self, query: str):
|
||||
self._calls.append(("get_triplets", query))
|
||||
return [{"s": "a", "p": "b", "o": "c"}]
|
||||
|
||||
async def resolve_edges_to_text(self, triplets):
|
||||
self._calls.append(("resolve_edges_to_text", len(triplets)))
|
||||
return "edges->text"
|
||||
|
||||
async def _fake_graph_collect_ids(self, **kwargs):
|
||||
return ["e1", "e2"]
|
||||
|
||||
async def _fake_graph_collect_events(self, ids):
|
||||
return [
|
||||
{
|
||||
"events": [
|
||||
{"id": "e1", "description": "E1"},
|
||||
{"id": "e2", "description": "E2"},
|
||||
{"id": "e3", "description": "E3"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
async def _fake_vector_embed(self, texts):
|
||||
assert isinstance(texts, list) and texts
|
||||
return [[0.0, 1.0, 2.0]]
|
||||
|
||||
async def _fake_vector_search(self, **kwargs):
|
||||
return [
|
||||
SimpleNamespace(payload={"id": "e2"}, score=0.05),
|
||||
SimpleNamespace(payload={"id": "e1"}, score=0.10),
|
||||
]
|
||||
|
||||
async def get_context(self, query: str):
|
||||
time_from, time_to = await self.extract_time_from_query(query)
|
||||
|
||||
if not (time_from or time_to):
|
||||
triplets = await self.get_triplets(query)
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
|
||||
ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
|
||||
relevant_events = await self._fake_graph_collect_events(ids)
|
||||
|
||||
_ = await self._fake_vector_embed([query])
|
||||
vector_search_results = await self._fake_vector_search(
|
||||
collection_name="Event_name", query_vector=[0.0], limit=0
|
||||
)
|
||||
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
||||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
|
||||
# Test get_context fallback to triplets when no time is extracted
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
    """Without an extracted time range, get_context takes the triplet path."""
    retriever = _FakeRetriever(top_k=2)
    context = await retriever.get_context("no_time")
    assert context == "edges->text"
    # Call order: triplets are fetched first, then resolved to text.
    assert [name for name, _ in retriever._calls] == ["get_triplets", "resolve_edges_to_text"]
|
||||
|
||||
|
||||
# Test get_context when time is extracted and vector ranking is applied
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
    """With a time range, events are ranked by vector score and cut to top_k."""
    retriever = _FakeRetriever(top_k=2)
    context = await retriever.get_context("both time")
    # e2 scored better (0.05 < 0.10), so E2 leads the rendered context.
    assert context.startswith("E2")
    assert "#####################" in context
    assert "E1" in context
    assert "E3" not in context
|
||||
|
|
@ -1,608 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_triplet_search,
|
||||
get_memory_fragment,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
    """Stand-in for a vector-engine scored search result."""

    def __init__(self, id, score, payload=None):
        self.id = id
        self.score = score
        # A falsy payload (None or {}) is replaced with a fresh empty dict.
        if payload:
            self.payload = payload
        else:
            self.payload = {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
    """An empty query string must be rejected with ValueError."""
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query="")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
    """Passing None as the query must be rejected with ValueError."""
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
    """A negative top_k must be rejected with ValueError."""
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=-1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
    """top_k == 0 must be rejected with ValueError (top_k is strictly positive)."""
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
    """Test that wide_search_limit is applied for global search (node_name=None).

    Also asserts that at least one search happened, so the per-call limit
    check below cannot pass vacuously when the loop body never runs.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,  # Global search
            wide_search_top_k=75,
        )

    # Guard against a vacuous pass: the loop is skipped if search was never called.
    assert mock_vector_engine.search.call_args_list
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 75
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
    """Test that wide_search_limit is None for filtered search (node_name provided).

    Also asserts that at least one search happened, so the per-call limit
    check below cannot pass vacuously when the loop body never runs.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=["Node1"],
            wide_search_top_k=50,
        )

    # Guard against a vacuous pass: the loop is skipped if search was never called.
    assert mock_vector_engine.search.call_args_list
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
    """Test that wide_search_top_k defaults to 100.

    Also asserts that at least one search happened, so the per-call limit
    check below cannot pass vacuously when the loop body never runs.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    # Guard against a vacuous pass: the loop is skipped if search was never called.
    assert mock_vector_engine.search.call_args_list
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
    """When no collections are given, the default set is searched, in order."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    ):
        await brute_force_triplet_search(query="test")

    searched = [call[1]["collection_name"] for call in engine.search.call_args_list]
    # Exact order matters: edge collection is searched last.
    assert searched == [
        "Entity_name",
        "TextSummary_text",
        "EntityType_name",
        "DocumentChunk_text",
        "EdgeType_relationship_name",
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
    """Caller-supplied collections are searched (plus the mandatory edge collection)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[])

    requested = ["CustomCol1", "CustomCol2"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    ):
        await brute_force_triplet_search(query="test", collections=requested)

    searched = {call[1]["collection_name"] for call in engine.search.call_args_list}
    assert searched == set(requested) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
    """EdgeType_relationship_name is searched even when absent from `collections`."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[])

    without_edge = ["Entity_name", "TextSummary_text"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    ):
        await brute_force_triplet_search(query="test", collections=without_edge)

    searched = [call[1]["collection_name"] for call in engine.search.call_args_list]
    assert "EdgeType_relationship_name" in searched
    # Nothing beyond the requested collections plus the mandatory edge collection.
    assert set(searched) == set(without_edge) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
    """When every collection yields nothing, the search returns an empty list."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    ):
        assert await brute_force_triplet_search(query="test") == []
|
||||
|
||||
|
||||
# Tests for query embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
    """Test that query is embedded before searching.

    Also asserts that at least one search happened, so the per-call
    query-vector check below cannot pass vacuously.
    """
    query_text = "test query"
    expected_vector = [0.1, 0.2, 0.3]

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query=query_text)

    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])

    # Guard against a vacuous pass: the loop is skipped if search was never called.
    assert mock_vector_engine.search.call_args_list
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["query_vector"] == expected_vector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
    """Global search forwards the IDs of all scored hits to the fragment factory."""
    hits = [
        MockScoredResult("node1", 0.95),
        MockScoredResult("node2", 0.87),
        MockScoredResult("node3", 0.92),
    ]

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=hits)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    passed_ids = get_fragment.call_args[1]["relevant_ids_to_filter"]
    assert set(passed_ids) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
    """A caller-supplied memory fragment suppresses fragment construction."""
    supplied = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
        ) as get_fragment,
    ):
        await brute_force_triplet_search(
            query="test",
            memory_fragment=supplied,
            node_name=["node"],
        )

    # The factory must not run when a fragment is provided by the caller.
    get_fragment.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
    """Without a caller-supplied fragment, exactly one fragment is created."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=["node"])

    get_fragment.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
    """The caller's top_k is forwarded as `k` to the importance calculation."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    requested_top_k = 15
    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ),
    ):
        await brute_force_triplet_search(query="test", top_k=requested_top_k, node_name=["n"])

    fragment.calculate_top_triplet_importances.assert_called_once_with(k=requested_top_k)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """EntityNotFoundError during projection yields an empty CogneeGraph, not a raise."""
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=graph_engine,
    ):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
    """Any generic projection failure also yields an empty CogneeGraph."""
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=graph_engine,
    ):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
    """IDs appearing in multiple collections are passed through exactly once."""

    def fake_search(*args, **kwargs):
        # node1 is returned by two collections to exercise deduplication.
        per_collection = {
            "Entity_name": [
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            ],
            "TextSummary_text": [
                MockScoredResult("node1", 0.90),
                MockScoredResult("node3", 0.92),
            ],
        }
        return per_collection.get(kwargs.get("collection_name"), [])

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=fake_search)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    passed_ids = get_fragment.call_args[1]["relevant_ids_to_filter"]
    assert set(passed_ids) == {"node1", "node2", "node3"}
    # No duplicates survive: 3 unique IDs -> exactly 3 entries.
    assert len(passed_ids) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Hits from EdgeType_relationship_name never contribute to node ID filtering."""

    def fake_search(*args, **kwargs):
        per_collection = {
            "Entity_name": [MockScoredResult("node1", 0.95)],
            "EdgeType_relationship_name": [MockScoredResult("edge1", 0.88)],
        }
        return per_collection.get(kwargs.get("collection_name"), [])

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=fake_search)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

    # Only the node hit survives; edge1 must be filtered out.
    assert get_fragment.call_args[1]["relevant_ids_to_filter"] == ["node1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Scored results lacking an `id` attribute are silently skipped."""

    class _NoIdResult:
        """Scored result deliberately missing the `id` attribute."""

        def __init__(self, score):
            self.score = score

    def fake_search(*args, **kwargs):
        if kwargs.get("collection_name") == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                _NoIdResult(0.90),
                MockScoredResult("node2", 0.87),
            ]
        return []

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=fake_search)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    passed_ids = get_fragment.call_args[1]["relevant_ids_to_filter"]
    assert set(passed_ids) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """ID extraction works when the engine returns a tuple instead of a list."""

    def fake_search(*args, **kwargs):
        if kwargs.get("collection_name") == "Entity_name":
            # Deliberately a tuple, not a list.
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        return []

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=fake_search)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    passed_ids = get_fragment.call_args[1]["relevant_ids_to_filter"]
    assert set(passed_ids) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """ID extraction merges hits across collections, skipping empty ones."""

    def fake_search(*args, **kwargs):
        per_collection = {
            "Entity_name": [MockScoredResult("node1", 0.95)],
            "TextSummary_text": [],
            "EntityType_name": [MockScoredResult("node2", 0.92)],
        }
        return per_collection.get(kwargs.get("collection_name"), [])

    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=fake_search)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    passed_ids = get_fragment.call_args[1]["relevant_ids_to_filter"]
    assert set(passed_ids) == {"node1", "node2"}
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_engine():
    """Mock vector engine whose collection exists and whose search is awaitable."""
    return AsyncMock(
        has_collection=AsyncMock(return_value=True),
        search=AsyncMock(),
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Matching triplets are joined with newlines, in search-result order."""
    first, second = MagicMock(), MagicMock()
    first.payload = {"text": "Alice knows Bob"}
    second.payload = {"text": "Bob works at Tech Corp"}
    mock_vector_engine.search.return_value = [first, second]

    retriever = TripletRetriever(top_k=5)

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert context == "Alice knows Bob\nBob works at Tech Corp"
    # top_k is forwarded as the search limit on the triplet collection.
    mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
    """A missing Triplet_text collection surfaces as NoDataError with guidance."""
    mock_vector_engine.has_collection.return_value = False

    retriever = TripletRetriever()

    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(NoDataError, match="create_triplet_embeddings"),
    ):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """No matching triplets yields an empty context string, not an error."""
    mock_vector_engine.search.return_value = []

    retriever = TripletRetriever()

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert context == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """CollectionNotFoundError from the engine is translated into NoDataError."""
    mock_vector_engine.has_collection.side_effect = CollectionNotFoundError(
        "Collection not found"
    )

    retriever = TripletRetriever()

    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(NoDataError, match="No data found"),
    ):
        await retriever.get_context("test query")
|
||||
Loading…
Add table
Reference in a new issue