From 4e8845c117ecf892c3f5554c94de4f9f1171b9ff Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:11:29 +0100 Subject: [PATCH] chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR restructures/adds integration and unit tests for the retrieval module. -Old unit tests were updated and moved under integration tests + fixtures added -Added missing integration tests for all core retrieval business logic -Covered 100% of the core retrievers with tests -Minor changes (dead code deletion, typo fixed) ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Changes** * TripletRetriever now returns up to 5 results by default (was 1), providing richer context. 
* **Tests** * Reorganized test coverage: many unit tests removed and replaced with comprehensive integration tests across retrieval components (graph, chunks, RAG, summaries, temporal, triplets, structured output). * **Chores** * Simplified triplet formatting logic and removed debug output. ✏️ Tip: You can customize this high-level summary in your review settings. --- cognee/modules/retrieval/triplet_retriever.py | 2 +- .../utils/brute_force_triplet_search.py | 18 - .../retrieval/test_chunks_retriever.py | 252 ++++++++ .../test_graph_completion_retriever.py | 268 ++++++++ ..._completion_retriever_context_extension.py | 226 +++++++ .../test_graph_completion_retriever_cot.py | 218 +++++++ .../test_rag_completion_retriever.py | 254 ++++++++ .../retrieval/test_structured_output.py} | 162 ++--- .../retrieval/test_summaries_retriever.py | 184 ++++++ .../retrieval/test_temporal_retriever.py | 306 +++++++++ .../retrieval/test_triplet_retriever.py | 35 + .../eval_framework/benchmark_adapters_test.py | 25 + .../eval_framework/corpus_builder_test.py | 37 +- .../retrieval/chunks_retriever_test.py | 201 ------ .../retrieval/conversation_history_test.py | 154 ----- ...letion_retriever_context_extension_test.py | 177 ----- .../graph_completion_retriever_cot_test.py | 170 ----- .../graph_completion_retriever_test.py | 223 ------- .../rag_completion_retriever_test.py | 205 ------ .../retrieval/summaries_retriever_test.py | 159 ----- .../retrieval/temporal_retriever_test.py | 224 ------- .../test_brute_force_triplet_search.py | 608 ------------------ .../retrieval/triplet_retriever_test.py | 83 --- 23 files changed, 1888 insertions(+), 2303 deletions(-) create mode 100644 cognee/tests/integration/retrieval/test_chunks_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py create mode 100644 
cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py create mode 100644 cognee/tests/integration/retrieval/test_rag_completion_retriever.py rename cognee/tests/{unit/modules/retrieval/structured_output_test.py => integration/retrieval/test_structured_output.py} (65%) create mode 100644 cognee/tests/integration/retrieval/test_summaries_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_temporal_retriever.py delete mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/conversation_history_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/summaries_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py delete mode 100644 cognee/tests/unit/modules/retrieval/triplet_retriever_test.py diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py index d251d113a..b9d006312 100644 --- a/cognee/modules/retrieval/triplet_retriever.py +++ b/cognee/modules/retrieval/triplet_retriever.py @@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path - self.top_k = top_k if top_k is not None else 1 + self.top_k = top_k if top_k is not None else 5 self.system_prompt = system_prompt async def get_context(self, query: str) -> str: diff --git 
a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index bd412e0ca..a70fa661b 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -16,24 +16,6 @@ logger = get_logger(level=ERROR) def format_triplets(edges): - print("\n\n\n") - - def filter_attributes(obj, attributes): - """Helper function to filter out non-None properties, including nested dicts.""" - result = {} - for attr in attributes: - value = getattr(obj, attr, None) - if value is not None: - # If the value is a dict, extract relevant keys from it - if isinstance(value, dict): - nested_values = { - k: v for k, v in value.items() if k in attributes and v is not None - } - result[attr] = nested_values - else: - result[attr] = value - return result - triplets = [] for edge in edges: node1 = edge.node1 diff --git a/cognee/tests/integration/retrieval/test_chunks_retriever.py b/cognee/tests/integration/retrieval/test_chunks_retriever.py new file mode 100644 index 000000000..d2e5e6149 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_chunks_retriever.py @@ -0,0 +1,252 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import List +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + 
chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex") + + 
cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty") + data_directory_path = str(base_dir / 
".data_storage/test_chunks_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify ChunksRetriever can retrieve multiple chunks.""" + retriever = ChunksRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert any(chunk["text"] == "Steve Rodger" for chunk in context), ( + "Failed to get Steve Rodger chunk" + ) + + +@pytest.mark.asyncio +async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever respects top_k parameter.""" + retriever = ChunksRetriever(top_k=2) + + context = await retriever.get_context("Employee") + + assert isinstance(context, list), "Context should be a list" + assert len(context) <= 2, "Should respect top_k limit" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever can retrieve chunk context (complex).""" + retriever = ChunksRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify ChunksRetriever handles empty graph correctly.""" + retriever = ChunksRetriever() + + with pytest.raises(NoDataError): + await 
retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py new file mode 100644 index 000000000..7367b353b --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py @@ -0,0 +1,268 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + company2 = Company(name="Canva", description="Canvas is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + 
works_for=company1, + ) + person2 = Person( + name="Ike Loma", description="This is description about Ike Loma", works_for=company1 + ) + person3 = Person( + name="Jason Statham", + description="This is description about Jason Statham", + works_for=company1, + ) + person4 = Person( + name="Mike Broski", + description="This is description about Mike Broski", + works_for=company2, + ) + person5 = Person( + name="Christina Mayer", + description="This is description about Christina Mayer", + works_for=company2, + ) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike 
Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionRetriever can retrieve context (simple).""" + retriever = GraphCompletionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + # Ensure the top-level sections are present + assert "Nodes:" in context, "Missing 'Nodes:' section in context" + assert "Connections:" in context, 
"Missing 'Connections:' section in context" + + # --- Nodes headers --- + assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" + assert "Node: Figma" in context, "Missing node header for Figma" + assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" + assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" + assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" + assert "Node: Canva" in context, "Missing node header for Canva" + assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" + + # --- Node contents --- + assert ( + "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" + in context + ), "Description block for Steve Rodger altered" + assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( + "Description block for Figma altered" + ) + assert ( + "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" + in context + ), "Description block for Ike Loma altered" + assert ( + "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" + in context + ), "Description block for Jason Statham altered" + assert ( + "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" + in context + ), "Description block for Mike Broski altered" + assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( + "Description block for Canva altered" + ) + assert ( + "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" + in context + ), "Description block for Christina Mayer altered" + + # --- Connections --- + assert "Steve Rodger --[works_for]--> Figma" in context, ( + "Connection Steve Rodger→Figma missing or changed" + ) + assert "Ike Loma --[works_for]--> Figma" in context, ( + "Connection Ike Loma→Figma missing or changed" + ) + assert "Jason Statham 
--[works_for]--> Figma" in context, ( + "Connection Jason Statham→Figma missing or changed" + ) + assert "Mike Broski --[works_for]--> Canva" in context, ( + "Connection Mike Broski→Canva missing or changed" + ) + assert "Christina Mayer --[works_for]--> Canva" in context, ( + "Connection Christina Mayer→Canva missing or changed" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionRetriever can retrieve context (complex).""" + retriever = GraphCompletionRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + +@pytest.mark.asyncio +async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever handles empty graph correctly.""" + retriever = GraphCompletionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + +@pytest.mark.asyncio +async def test_graph_completion_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py new file mode 100644 index 
000000000..c87de16ef --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py @@ -0,0 +1,226 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_simple" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_simple" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + 
"""Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_complex" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_complex" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def 
setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_extension_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple).""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex).""" + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await resolve_edges_to_text( + await retriever.get_context("Who works at Figma and drives Tesla?") + ) + + assert 
"Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly.""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionContextExtensionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py new file mode 100644 index 000000000..0db035e03 --- /dev/null +++ 
b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -0,0 +1,218 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_simple" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = 
pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_complex" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = 
pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple).""" + retriever = GraphCompletionCotRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex).""" + retriever = GraphCompletionCotRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" 
in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly.""" + retriever = GraphCompletionCotRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionCotRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_rag_completion_retriever.py b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py new file mode 100644 index 000000000..b01d58160 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py @@ -0,0 +1,254 @@ +import os +from typing import List +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import 
get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: 
+ await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await 
add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context (simple).""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Mike") + + assert isinstance(context, str), "Context should be a string" + assert "Mike Broski" in context, "Failed to get Mike Broski" + + +@pytest.mark.asyncio +async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context from multiple chunks.""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, str), "Context should be a string" + assert "Steve Rodger" in context, "Failed to get Steve Rodger" + + +@pytest.mark.asyncio +async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify CompletionRetriever can retrieve context (complex).""" + # TODO: top_k doesn't affect the 
output, it should be fixed. + retriever = CompletionRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify CompletionRetriever handles empty graph correctly.""" + retriever = CompletionRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/integration/retrieval/test_structured_output.py similarity index 65% rename from cognee/tests/unit/modules/retrieval/structured_output_test.py rename to cognee/tests/integration/retrieval/test_structured_output.py index 4ad3019ff..13ffd8eef 100644 --- a/cognee/tests/unit/modules/retrieval/structured_output_test.py +++ b/cognee/tests/integration/retrieval/test_structured_output.py @@ -1,9 +1,9 @@ import asyncio - -import pytest -import cognee -import pathlib import os +import pytest +import pathlib +import pytest_asyncio +import cognee from pydantic import BaseModel from cognee.low_level import setup, DataPoint @@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion(): _assert_structured_answer(structured_answer) -class TestStructuredOutputCompletion: - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - 
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest_asyncio.fixture +async def setup_test_environment(): + """Set up a clean test environment with graph and document data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion") + data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion") + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + yield + + try: await cognee.prune.prune_data() 
await cognee.prune.prune_system(metadata=True) - await setup() + except Exception: + pass - class Company(DataPoint): - name: str - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await add_data_points(entities) - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - await _test_get_structured_graph_completion_cot() - await _test_get_structured_graph_completion() - await _test_get_structured_graph_completion_temporal() - await _test_get_structured_graph_completion_rag() - await _test_get_structured_graph_completion_context_extension() - await _test_get_structured_entity_completion() +@pytest.mark.asyncio +async def test_get_structured_completion(setup_test_environment): + """Integration test: verify structured output completion for all retrievers.""" + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await 
_test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/integration/retrieval/test_summaries_retriever.py b/cognee/tests/integration/retrieval/test_summaries_retriever.py new file mode 100644 index 000000000..a2f4e40b3 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_summaries_retriever.py @@ -0,0 +1,184 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.tasks.summarization.models import TextSummary +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_with_summaries(): + """Set up a clean test environment with summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + 
cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk1_summary = TextSummary( + text="S.R.", + made_from=chunk1, + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2_summary = TextSummary( + text="M.B.", + made_from=chunk2, + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3_summary = TextSummary( + text="C.M.", + made_from=chunk3, + ) + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk4_summary = TextSummary( + text="R.R.", + made_from=chunk4, + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5_summary = TextSummary( + text="H.Y.", + made_from=chunk5, + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6_summary = TextSummary( + text="C.H.", + made_from=chunk6, + ) + + entities = [ + chunk1_summary, + chunk2_summary, + chunk3_summary, + chunk4_summary, + chunk5_summary, + chunk6_summary, + ] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context_empty") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + 
cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_summaries_retriever_context(setup_test_environment_with_summaries): + """Integration test: verify SummariesRetriever can retrieve summary context.""" + retriever = SummariesRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify SummariesRetriever handles empty graph correctly.""" + retriever = SummariesRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) + + context = await retriever.get_context("Christina Mayer") + assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/integration/retrieval/test_temporal_retriever.py b/cognee/tests/integration/retrieval/test_temporal_retriever.py new file mode 100644 index 000000000..8ce3b32f4 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_temporal_retriever.py @@ -0,0 +1,306 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.engine.models.Event import Event +from cognee.modules.engine.models.Timestamp import Timestamp 
+from cognee.modules.engine.models.Interval import Interval + + +@pytest_asyncio.fixture +async def setup_test_environment_with_events(): + """Set up a clean test environment with temporal events.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + # Create timestamps for events + timestamp1 = Timestamp( + time_at=1609459200, # 2021-01-01 00:00:00 + year=2021, + month=1, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-01-01T00:00:00", + ) + + timestamp2 = Timestamp( + time_at=1612137600, # 2021-02-01 00:00:00 + year=2021, + month=2, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-02-01T00:00:00", + ) + + timestamp3 = Timestamp( + time_at=1614556800, # 2021-03-01 00:00:00 + year=2021, + month=3, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-03-01T00:00:00", + ) + + timestamp4 = Timestamp( + time_at=1625097600, # 2021-07-01 00:00:00 + year=2021, + month=7, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-07-01T00:00:00", + ) + + timestamp5 = Timestamp( + time_at=1633046400, # 2021-10-01 00:00:00 + year=2021, + month=10, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-10-01T00:00:00", + ) + + # Create interval for event spanning multiple timestamps + interval1 = Interval(time_from=timestamp2, time_to=timestamp3) + + # Create events with timestamps + event1 = Event( + name="Project Alpha Launch", + description="Launched Project Alpha at the beginning of 2021", + at=timestamp1, + location="San Francisco", + ) + + event2 = Event( + name="Team Meeting", + 
description="Monthly team meeting discussing Q1 goals", + during=interval1, + location="New York", + ) + + event3 = Event( + name="Product Release", + description="Released new product features in July", + at=timestamp4, + location="Remote", + ) + + event4 = Event( + name="Company Retreat", + description="Annual company retreat in October", + at=timestamp5, + location="Lake Tahoe", + ) + + entities = [event1, event2, event3, event4] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_graph_data(): + """Set up a clean test environment with graph data (for fallback to triplets).""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + + entities = [company1, person1] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = 
str(base_dir / ".cognee_system/test_temporal_retriever_empty") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events within time range.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in January 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Project Alpha" in context or "Launch" in context, ( + "Should retrieve Project Alpha Launch event from January 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events at specific time.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in July 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Product Release" in context or "July" in context, ( + "Should retrieve Product Release event from July 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_fallback_to_triplets( + setup_test_environment_with_graph_data, +): + """Integration test: verify TemporalRetriever falls back to triplets when no time extracted.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("Who works at Figma?") + + 
assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Steve" in context or "Figma" in context, ( + "Should retrieve graph data via triplet search fallback" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty): + """Integration test: verify TemporalRetriever handles empty graph correctly.""" + retriever = TemporalRetriever() + + context = await retriever.get_context("What happened?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) >= 0, "Context should be a string (possibly empty)" + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can generate completions.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("What happened in January 2021?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data): + """Integration test: verify TemporalRetriever get_completion works with triplet fallback.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("Who works at Figma?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever respects top_k 
parameter.""" + retriever = TemporalRetriever(top_k=2) + + context = await retriever.get_context("What happened in 2021?") + + assert isinstance(context, str), "Context should be a string" + separator_count = context.count("#####################") + assert separator_count <= 1, "Should respect top_k limit of 2 events" + + +@pytest.mark.asyncio +async def test_temporal_retriever_multiple_events(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve multiple events.""" + retriever = TemporalRetriever(top_k=10) + + context = await retriever.get_context("What events occurred in 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + assert ( + "Project Alpha" in context + or "Team Meeting" in context + or "Product Release" in context + or "Company Retreat" in context + ), "Should retrieve at least one event from 2021" diff --git a/cognee/tests/integration/retrieval/test_triplet_retriever.py b/cognee/tests/integration/retrieval/test_triplet_retriever.py index e547b6cbe..ebe853e08 100644 --- a/cognee/tests/integration/retrieval/test_triplet_retriever.py +++ b/cognee/tests/integration/retrieval/test_triplet_retriever.py @@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip context = await retriever.get_context("Alice") assert "Alice knows Bob" in context, "Failed to get Alice triplet" + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever can retrieve multiple triplets.""" + retriever = TripletRetriever(top_k=5) + + context = await retriever.get_context("Bob") + + assert "Alice knows Bob" in context or "Bob works at Tech Corp" in context, ( + "Failed to get Bob-related 
triplets" + ) + + +@pytest.mark.asyncio +async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever respects top_k parameter.""" + retriever = TripletRetriever(top_k=1) + + context = await retriever.get_context("Alice") + + assert isinstance(context, str), "Context should be a string" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_empty(setup_test_environment_empty): + """Integration test: verify TripletRetriever handles empty graph correctly.""" + await setup() + + retriever = TripletRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Alice") diff --git a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py index 70ec43cf8..b18012594 100644 --- a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py +++ b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py @@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\ {"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]} """ +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + ADAPTER_CLASSES = [ HotpotQAAdapter, @@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass): adapter = AdapterClass() result = adapter.load_corpus() + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + result = adapter.load_corpus() + else: adapter = AdapterClass() result = adapter.load_corpus() @@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass): ): adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + corpus_list, qa_pairs = adapter.load_corpus(limit=limit) else: adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) diff --git a/cognee/tests/unit/eval_framework/corpus_builder_test.py b/cognee/tests/unit/eval_framework/corpus_builder_test.py index 14136bea5..53f886b58 100644 --- a/cognee/tests/unit/eval_framework/corpus_builder_test.py +++ b/cognee/tests/unit/eval_framework/corpus_builder_test.py @@ -2,15 +2,38 @@ import pytest from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor from cognee.infrastructure.databases.graph import get_graph_engine from unittest.mock import AsyncMock, patch +from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"] +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + @pytest.mark.parametrize("benchmark", benchmark_options) def test_corpus_builder_load_corpus(benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" @@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark): @patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock) async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - questions = await corpus_builder.build_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" ) diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py deleted file mode 100644 index 44786f79d..000000000 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import pytest -import pathlib -from typing import List -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestChunksRetriever: - @pytest.mark.asyncio - async def test_chunk_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - 
text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - - await add_data_points(entities) - - retriever = ChunksRetriever() - - context = await retriever.get_context("Mike") - - assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_chunk_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - 
chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] - - await add_data_points(entities) - - retriever = ChunksRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = ChunksRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py deleted file mode 100644 index d464a99d8..000000000 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ /dev/null @@ -1,154 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from cognee.context_global_variables import session_user -import importlib - - -def create_mock_cache_engine(qa_history=None): - mock_cache = AsyncMock() - if qa_history is None: - 
qa_history = [] - mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) - mock_cache.add_qa = AsyncMock(return_value=None) - return mock_cache - - -def create_mock_user(): - mock_user = MagicMock() - mock_user.id = "test-user-id-123" - return mock_user - - -class TestConversationHistoryUtils: - @pytest.mark.asyncio - async def test_get_conversation_history_returns_empty_when_no_history(self): - user = create_mock_user() - session_user.set(user) - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - from cognee.modules.retrieval.utils.session_cache import get_conversation_history - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_formats_history_correctly(self): - """Test get_conversation_history formats Q&A history with correct structure.""" - user = create_mock_user() - session_user.set(user) - - mock_history = [ - { - "time": "2024-01-15 10:30:45", - "question": "What is AI?", - "context": "AI is artificial intelligence", - "answer": "AI stands for Artificial Intelligence", - } - ] - mock_cache = create_mock_cache_engine(mock_history) - - # Import the real module to patch safely - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert "Previous conversation:" in result - 
assert "[2024-01-15 10:30:45]" in result - assert "QUESTION: What is AI?" in result - assert "CONTEXT: AI is artificial intelligence" in result - assert "ANSWER: AI stands for Artificial Intelligence" in result - - @pytest.mark.asyncio - async def test_save_to_session_cache_saves_correctly(self): - """Test save_conversation_history calls add_qa with correct parameters.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="What is Python?", - context_summary="Python is a programming language", - answer="Python is a high-level programming language", - session_id="my_session", - ) - - assert result is True - mock_cache.add_qa.assert_called_once() - - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["question"] == "What is Python?" 
- assert call_kwargs["context"] == "Python is a programming language" - assert call_kwargs["answer"] == "Python is a high-level programming language" - assert call_kwargs["session_id"] == "my_session" - - @pytest.mark.asyncio - async def test_save_to_session_cache_uses_default_session_when_none(self): - """Test save_conversation_history uses 'default_session' when session_id is None.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - session_id=None, - ) - - assert result is True - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["session_id"] == "default_session" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py deleted file mode 100644 index 0e21fe351..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - 
GraphCompletionContextExtensionRetriever, -) - - -class TestGraphCompletionWithContextExtensionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_extension_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_simple", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_simple", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_extension_context_complex(self): - 
system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_complex", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever(top_k=20) - - context = await resolve_edges_to_text( - await retriever.get_context("Who works at Figma and drives Tesla?") - ) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, 
"Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_extension_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionContextExtensionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py deleted file mode 100644 index 206cfaf84..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee 
-from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever - - -class TestGraphCompletionCoTRetriever: - @pytest.mark.asyncio - async def test_graph_completion_cot_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in 
answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_cot_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_cot_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who 
works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_cot_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionCotRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py deleted file mode 100644 index f462baced..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ /dev/null @@ -1,223 +0,0 @@ -import os -import pytest -import 
pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever - - -class TestGraphCompletionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - description: str - - class Person(DataPoint): - name: str - description: str - works_for: Company - - company1 = Company(name="Figma", description="Figma is a company") - company2 = Company(name="Canva", description="Canvas is a company") - person1 = Person( - name="Steve Rodger", - description="This is description about Steve Rodger", - works_for=company1, - ) - person2 = Person( - name="Ike Loma", description="This is description about Ike Loma", works_for=company1 - ) - person3 = Person( - name="Jason Statham", - description="This is description about Jason Statham", - works_for=company1, - ) - person4 = Person( - name="Mike Broski", - description="This is description about Mike Broski", - works_for=company2, - ) - person5 = Person( - name="Christina Mayer", - description="This is description about Christina Mayer", - works_for=company2, - ) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever() - - context = await 
resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - # Ensure the top-level sections are present - assert "Nodes:" in context, "Missing 'Nodes:' section in context" - assert "Connections:" in context, "Missing 'Connections:' section in context" - - # --- Nodes headers --- - assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" - assert "Node: Figma" in context, "Missing node header for Figma" - assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" - assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" - assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" - assert "Node: Canva" in context, "Missing node header for Canva" - assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" - - # --- Node contents --- - assert ( - "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" - in context - ), "Description block for Steve Rodger altered" - assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( - "Description block for Figma altered" - ) - assert ( - "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" - in context - ), "Description block for Ike Loma altered" - assert ( - "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" - in context - ), "Description block for Jason Statham altered" - assert ( - "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" - in context - ), "Description block for Mike Broski altered" - assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( - "Description block for Canva altered" - ) - assert ( - "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" - in context - ), "Description block for Christina Mayer altered" - - # --- Connections --- - assert "Steve Rodger 
--[works_for]--> Figma" in context, ( - "Connection Steve Rodger→Figma missing or changed" - ) - assert "Ike Loma --[works_for]--> Figma" in context, ( - "Connection Ike Loma→Figma missing or changed" - ) - assert "Jason Statham --[works_for]--> Figma" in context, ( - "Connection Jason Statham→Figma missing or changed" - ) - assert "Mike Broski --[works_for]--> Canva" in context, ( - "Connection Mike Broski→Canva missing or changed" - ) - assert "Christina Mayer --[works_for]--> Canva" in context, ( - "Connection Christina Mayer→Canva missing or changed" - ) - - @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 
= Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py deleted file mode 100644 index 9bfed68f3..000000000 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ /dev/null @@ -1,205 +0,0 @@ -import os -from typing import List -import pytest -import pathlib -import cognee - -from cognee.low_level import setup -from 
cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestRAGCompletionRetriever: - @pytest.mark.asyncio - async def test_rag_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, 
chunk2, chunk3] - - await add_data_points(entities) - - retriever = CompletionRetriever() - - context = await retriever.get_context("Mike") - - assert context == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_rag_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, 
chunk5, chunk6] - - await add_data_points(entities) - - # TODO: top_k doesn't affect the output, it should be fixed. - retriever = CompletionRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_get_rag_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = CompletionRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py deleted file mode 100644 index 5f4b93425..000000000 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import pytest -import pathlib - -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.tasks.summarization.models import TextSummary -from cognee.modules.data.processing.document_types import TextDocument -from 
cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever - - -class TestSummariesRetriever: - @pytest.mark.asyncio - async def test_chunk_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk1_summary = TextSummary( - text="S.R.", - made_from=chunk1, - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2_summary = TextSummary( - text="M.B.", - made_from=chunk2, - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3_summary = TextSummary( - text="C.M.", - made_from=chunk3, - ) - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk4_summary = TextSummary( - text="R.R.", - made_from=chunk4, - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) 
- chunk5_summary = TextSummary( - text="H.Y.", - made_from=chunk5, - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6_summary = TextSummary( - text="C.H.", - made_from=chunk6, - ) - - entities = [ - chunk1_summary, - chunk2_summary, - chunk3_summary, - chunk4_summary, - chunk5_summary, - chunk6_summary, - ] - - await add_data_points(entities) - - retriever = SummariesRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = SummariesRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) - - context = await retriever.get_context("Christina Mayer") - assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py deleted file mode 100644 index c3c6a47f6..000000000 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ /dev/null @@ -1,224 +0,0 @@ -from types import SimpleNamespace -import pytest - -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever - - -# Test 
TemporalRetriever initialization defaults and overrides -def test_init_defaults_and_overrides(): - tr = TemporalRetriever() - assert tr.top_k == 5 - assert tr.user_prompt_path == "graph_context_for_question.txt" - assert tr.system_prompt_path == "answer_simple_question.txt" - assert tr.time_extraction_prompt_path == "extract_query_time.txt" - - tr2 = TemporalRetriever( - top_k=3, - user_prompt_path="u.txt", - system_prompt_path="s.txt", - time_extraction_prompt_path="t.txt", - ) - assert tr2.top_k == 3 - assert tr2.user_prompt_path == "u.txt" - assert tr2.system_prompt_path == "s.txt" - assert tr2.time_extraction_prompt_path == "t.txt" - - -# Test descriptions_to_string with basic and empty results -def test_descriptions_to_string_basic_and_empty(): - tr = TemporalRetriever() - - results = [ - {"description": " First "}, - {"nope": "no description"}, - {"description": "Second"}, - {"description": ""}, - {"description": " Third line "}, - ] - - s = tr.descriptions_to_string(results) - assert s == "First\n#####################\nSecond\n#####################\nThird line" - - assert tr.descriptions_to_string([]) == "" - - -# Test filter_top_k_events sorts and limits correctly -@pytest.mark.asyncio -async def test_filter_top_k_events_sorts_and_limits(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3 - not in vector results"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - - assert [e["id"] for e in top] == ["e2", "e1"] - assert all("score" in e for e in top) - assert top[0]["score"] == 0.10 - assert top[1]["score"] == 0.20 - - -# Test filter_top_k_events handles unknown ids as infinite scores -@pytest.mark.asyncio -async def 
test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "known1", "description": "Known 1"}, - {"id": "unknown", "description": "Unknown"}, - {"id": "known2", "description": "Known 2"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in top] == ["known2", "known1"] - assert all(e["score"] != float("inf") for e in top) - - -# Test descriptions_to_string with unicode and newlines -def test_descriptions_to_string_unicode_and_newlines(): - tr = TemporalRetriever() - results = [ - {"description": "Line A\nwith newline"}, - {"description": "This is a description"}, - ] - s = tr.descriptions_to_string(results) - assert "Line A\nwith newline" in s - assert "This is a description" in s - assert s.count("#####################") == 1 - - -# Test filter_top_k_events when top_k is larger than available events -@pytest.mark.asyncio -async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): - tr = TemporalRetriever(top_k=10) - relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] - scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), - ] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["a", "b"] - - -# Test filter_top_k_events when scored_results is empty -@pytest.mark.asyncio -async def test_filter_top_k_events_handles_empty_scored_results(): - tr = TemporalRetriever(top_k=2) - relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] - scored_results = [] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["x", "y"] - assert all(e["score"] == float("inf") for e in out) - - -# Test 
filter_top_k_events error handling for missing structure -@pytest.mark.asyncio -async def test_filter_top_k_events_error_handling(): - tr = TemporalRetriever(top_k=2) - with pytest.raises((KeyError, TypeError)): - await tr.filter_top_k_events([{}], []) - - -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] - - async def extract_time_from_query(self, query: str): - if "both" in query: - return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] - - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" - - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] - - async def _fake_graph_collect_events(self, ids): - return [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, - ] - } - ] - - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] - - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) - - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) - - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) - - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - 
collection_name="Event_name", query_vector=[0.0], limit=0 - ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) - - -# Test get_context fallback to triplets when no time is extracted -@pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" - - -# Test get_context when time is extracted and vector ranking is applied -@pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py deleted file mode 100644 index 3dc9f38d9..000000000 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ /dev/null @@ -1,608 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch - -from cognee.modules.retrieval.utils.brute_force_triplet_search import ( - brute_force_triplet_search, - get_memory_fragment, -) -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError - - -class MockScoredResult: - """Mock class for vector search results.""" - - def __init__(self, id, score, payload=None): - self.id = id - self.score = score - self.payload = payload or {} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_empty_query(): - """Test that empty query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await 
brute_force_triplet_search(query="") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_none_query(): - """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await brute_force_triplet_search(query=None) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_negative_top_k(): - """Test that negative top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=-1) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_zero_top_k(): - """Test that zero top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=0) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_global_search(): - """Test that wide_search_limit is applied for global search (node_name=None).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=None, # Global search - wide_search_top_k=75, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 75 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): - """Test that wide_search_limit is None for filtered search (node_name provided).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - 
mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=["Node1"], - wide_search_top_k=50, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] is None - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_default(): - """Test that wide_search_top_k defaults to 100.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", node_name=None) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 100 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_default_collections(): - """Test that default collections are used when none provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test") - - expected_collections = [ - "Entity_name", - "TextSummary_text", - "EntityType_name", - "DocumentChunk_text", - "EdgeType_relationship_name", - ] - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert call_collections == expected_collections - - -@pytest.mark.asyncio -async def 
test_brute_force_triplet_search_custom_collections(): - """Test that custom collections are used when provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - custom_collections = ["CustomCol1", "CustomCol2"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=custom_collections) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_always_includes_edge_collection(): - """Test that EdgeType_relationship_name is always searched even when not in collections.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - collections_without_edge = ["Entity_name", "TextSummary_text"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=collections_without_edge) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert "EdgeType_relationship_name" in call_collections - assert set(call_collections) == set(collections_without_edge) | { - "EdgeType_relationship_name" - } - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_all_collections_empty(): - """Test that empty list is returned when all collections return no results.""" 
- mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - results = await brute_force_triplet_search(query="test") - assert results == [] - - -# Tests for query embedding - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_embeds_query(): - """Test that query is embedded before searching.""" - query_text = "test query" - expected_vector = [0.1, 0.2, 0.3] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query=query_text) - - mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["query_vector"] == expected_vector - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_extracts_node_ids_global_search(): - """Test that node IDs are extracted from search results for global search.""" - scored_results = [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - MockScoredResult("node3", 0.92), - ] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=scored_results) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - 
map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_reuses_provided_fragment(): - """Test that provided memory fragment is reused instead of creating new one.""" - provided_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" - ) as mock_get_fragment, - ): - await brute_force_triplet_search( - query="test", - memory_fragment=provided_fragment, - node_name=["node"], - ) - - mock_get_fragment.assert_not_called() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): - """Test that memory fragment is created when not provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment, - ): - await brute_force_triplet_search(query="test", node_name=["node"]) - - mock_get_fragment.assert_called_once() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): - """Test that custom top_k is passed to importance calculation.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ), - ): - custom_top_k = 15 - await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns 
empty graph when entity not found.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( - side_effect=EntityNotFoundError("Entity not found") - ) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_error(): - """Test that get_memory_fragment returns empty graph on generic error.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_deduplicates_node_ids(): - """Test that duplicate node IDs across collections are deduplicated.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ] - elif collection_name == "TextSummary_text": - return [ - MockScoredResult("node1", 0.90), - MockScoredResult("node3", 0.92), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - 
patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - assert len(call_kwargs["relevant_ids_to_filter"]) == 3 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_excludes_edge_collection(): - """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "EdgeType_relationship_name": - return [MockScoredResult("edge1", 0.88)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search( - query="test", - node_name=None, - collections=["Entity_name", "EdgeType_relationship_name"], - ) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert 
call_kwargs["relevant_ids_to_filter"] == ["node1"] - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_skips_nodes_without_ids(): - """Test that nodes without ID attribute are skipped.""" - - class ScoredResultNoId: - """Mock result without id attribute.""" - - def __init__(self, score): - self.score = score - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - ScoredResultNoId(0.90), - MockScoredResult("node2", 0.87), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_handles_tuple_results(): - """Test that both list and tuple results are handled correctly.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return ( - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ) - else: - return [] - - mock_vector_engine = AsyncMock() - 
mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_mixed_empty_collections(): - """Test ID extraction with mixed empty and non-empty collections.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "TextSummary_text": - return [] - elif collection_name == "EntityType_name": - return [MockScoredResult("node2", 0.92)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - 
return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py deleted file mode 100644 index d79aca428..000000000 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.triplet_retriever import TripletRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.has_collection = AsyncMock(return_value=True) - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of triplet context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Alice knows Bob"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Bob works at Tech Corp"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = TripletRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def 
test_get_context_no_collection(mock_vector_engine): - """Test that NoDataError is raised when Triplet_text collection doesn't exist.""" - mock_vector_engine.has_collection.return_value = False - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="create_triplet_embeddings"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty string is returned when no triplets are found.""" - mock_vector_engine.search.return_value = [] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "" - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query")