test: make search unit tests deterministic (#726)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
This commit is contained in:
Boris 2025-04-18 21:55:24 +02:00 committed by GitHub
parent 751eca7aaf
commit 675b66175f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 1127 additions and 824 deletions

View file

@ -8,12 +8,30 @@ on:
type: string
default: '3.11.x'
secrets:
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME:
required: true
GRAPHISTRY_PASSWORD:
required: true
LLM_PROVIDER:
required: true
LLM_MODEL:
required: true
LLM_ENDPOINT:
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
required: true
EMBEDDING_API_KEY:
required: true
EMBEDDING_API_VERSION:
required: true
env:
RUNTIME__LOG_LEVEL: ERROR
@ -60,6 +78,18 @@ jobs:
unit-tests:
name: Run Unit Tests
runs-on: ubuntu-22.04
env:
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4
@ -95,10 +125,20 @@ jobs:
name: Run Simple Examples
runs-on: ubuntu-22.04
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4
@ -117,10 +157,20 @@ jobs:
name: Run Basic Graph Tests
runs-on: ubuntu-22.04
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
steps:
- name: Check out repository
uses: actions/checkout@v4

View file

@ -1,4 +1,4 @@
name: Reusable Vector DB Tests
name: Reusable Graph DB Tests
on:
workflow_call:

View file

@ -8,26 +8,30 @@ on:
type: string
default: '["3.10.x", "3.11.x", "3.12.x"]'
secrets:
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME:
required: true
GRAPHISTRY_PASSWORD:
required: true
LLM_PROVIDER:
required: true
LLM_MODEL:
required: false
required: true
LLM_ENDPOINT:
required: false
required: true
LLM_API_KEY:
required: true
LLM_API_VERSION:
required: false
required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL:
required: false
required: true
EMBEDDING_ENDPOINT:
required: false
required: true
EMBEDDING_API_KEY:
required: false
required: true
EMBEDDING_API_VERSION:
required: false
required: true
env:
RUNTIME__LOG_LEVEL: ERROR
@ -55,6 +59,18 @@ jobs:
- name: Run unit tests
run: poetry run pytest cognee/tests/unit/
env:
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
- name: Run integration tests
if: ${{ !contains(matrix.os, 'windows') }}
@ -62,13 +78,16 @@ jobs:
- name: Run default basic pipeline
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}

View file

@ -8,6 +8,8 @@ on:
type: string
default: '3.11.x'
secrets:
LLM_PROVIDER:
required: true
LLM_MODEL:
required: true
LLM_ENDPOINT:
@ -16,6 +18,8 @@ on:
required: true
LLM_API_VERSION:
required: true
EMBEDDING_PROVIDER:
required: true
EMBEDDING_MODEL:
required: true
EMBEDDING_ENDPOINT:
@ -24,12 +28,6 @@ on:
required: true
EMBEDDING_API_VERSION:
required: true
OPENAI_API_KEY:
required: true
GRAPHISTRY_USERNAME:
required: true
GRAPHISTRY_PASSWORD:
required: true
jobs:
run-relational-db-migration-test-networkx:
@ -81,10 +79,13 @@ jobs:
- name: Run relational db test
env:
ENV: 'dev'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -141,10 +142,14 @@ jobs:
env:
ENV: 'dev'
GRAPH_DATABASE_PROVIDER: 'kuzu'
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@ -204,10 +209,14 @@ jobs:
GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
GRAPH_DATABASE_USERNAME: "neo4j"
LLM_PROVIDER: openai
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_PROVIDER: openai
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}

View file

@ -1,28 +0,0 @@
"""Add default user
Revision ID: 482cd6517ce4
Revises: 8057ae7329c2
Create Date: 2024-10-16 22:17:18.634638
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
from cognee.modules.users.methods import create_default_user, delete_user
# revision identifiers, used by Alembic.
revision: str = "482cd6517ce4"
down_revision: Union[str, None] = "8057ae7329c2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
def upgrade() -> None:
    # Seed the freshly-created database with the default user.
    # await_only bridges the async helper into Alembic's sync migration context.
    await_only(create_default_user())


def downgrade() -> None:
    # Remove the seeded default user by its well-known email address.
    await_only(delete_user("default_user@example.com"))

View file

@ -1,27 +0,0 @@
"""Initial migration
Revision ID: 8057ae7329c2
Revises:
Create Date: 2024-10-02 12:55:20.989372
"""
from typing import Sequence, Union
from sqlalchemy.util import await_only
from cognee.infrastructure.databases.relational import get_relational_engine
# revision identifiers, used by Alembic.
revision: str = "8057ae7329c2"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    # Create the relational database itself; this is the root migration
    # (down_revision is None), so there is no schema to build on yet.
    db_engine = get_relational_engine()
    await_only(db_engine.create_database())


def downgrade() -> None:
    # Drop the entire database created by upgrade().
    db_engine = get_relational_engine()
    await_only(db_engine.delete_database())

View file

@ -1,8 +1,7 @@
from typing import Union
from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function
@ -22,9 +21,6 @@ async def search(
if user is None:
user = await get_default_user()
if user is None:
raise UserNotFoundError
filtered_search_results = await search_function(
query_text,
query_type,

View file

@ -10,4 +10,5 @@ from .exceptions import (
ServiceError,
InvalidValueError,
InvalidAttributeError,
CriticalError,
)

View file

@ -53,3 +53,7 @@ class InvalidAttributeError(CogneeApiError):
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)
class CriticalError(CogneeApiError):
    """Marker base class for critical, non-recoverable errors.

    Subclasses in this change (DatabaseNotCreatedError, CollectionNotFoundError,
    NoDataError) signal conditions that require user intervention (e.g. calling
    ``await setup()`` or adding data) rather than retrying.
    """

    pass

View file

@ -7,4 +7,5 @@ This module defines a set of exceptions for handling various database errors
from .exceptions import (
EntityNotFoundError,
EntityAlreadyExistsError,
DatabaseNotCreatedError,
)

View file

@ -1,5 +1,15 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
from cognee.exceptions import CogneeApiError, CriticalError
class DatabaseNotCreatedError(CriticalError):
    """Raised when an operation needs the database but it was never created.

    The default message points the user at the fix: run ``await setup()``
    before issuing queries.
    """

    def __init__(
        self,
        message: str = "The database has not been created yet. Please call `await setup()` first.",
        name: str = "DatabaseNotCreatedError",
        status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        # Positional forwarding matches the (message, name, status_code) order
        # used by the sibling CogneeApiError subclasses in this file.
        super().__init__(message, name, status_code)
class EntityNotFoundError(CogneeApiError):

View file

@ -1,13 +1,13 @@
from typing import Protocol, Optional, Dict, Any, List, Tuple
from abc import abstractmethod, ABC
from uuid import UUID, uuid5, NAMESPACE_DNS
from cognee.modules.graph.relationship_manager import create_relationship
from functools import wraps
import inspect
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
from cognee.shared.logging_utils import get_logger
from datetime import datetime, timezone
logger = get_logger()
@ -44,20 +44,16 @@ def record_graph_changes(func):
async with db_engine.get_async_session() as session:
if func.__name__ == "add_nodes":
nodes = args[0]
nodes: List[DataPoint] = args[0]
for node in nodes:
try:
node_id = (
UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
)
node_id = UUID(str(node.id))
relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=node_id,
destination_node_id=node_id,
creator_function=f"{creator}.node",
node_label=node[1].get("type")
if isinstance(node, tuple)
else type(node).__name__,
node_label=getattr(node, "name", None) or str(node.id),
)
session.add(relationship)
await session.flush()
@ -74,7 +70,7 @@ def record_graph_changes(func):
target_id = UUID(str(edge[1]))
rel_type = str(edge[2])
relationship = GraphRelationshipLedger(
id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
source_node_id=source_id,
destination_node_id=target_id,
creator_function=f"{creator}.{rel_type}",

View file

@ -0,0 +1 @@
from .exceptions import CollectionNotFoundError

View file

@ -0,0 +1,12 @@
from fastapi import status
from cognee.exceptions import CriticalError
class CollectionNotFoundError(CriticalError):
    """Raised when a vector-database collection does not exist.

    Adapters (LanceDB, PGVector) raise this so retrievers can translate a
    missing collection into a user-facing "no data" condition.

    Args:
        message: Human-readable description, typically naming the collection.
        name: Error name reported to the API layer.
        status_code: HTTP status returned to clients.
    """

    def __init__(
        self,
        message,
        # Fixed copy-paste bug: the default was "DatabaseNotCreatedError",
        # which mislabeled this error in API responses and logs.
        name: str = "CollectionNotFoundError",
        status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        super().__init__(message, name, status_code)

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Generic, List, Optional, TypeVar, get_type_hints
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import lancedb
from lancedb.pydantic import LanceModel, Vector
@ -10,6 +10,7 @@ from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
@ -79,7 +80,6 @@ class LanceDBAdapter(VectorDBInterface):
connection = await self.get_connection()
payload_schema = type(data_points[0])
payload_schema = self.get_data_point_schema(payload_schema)
if not await self.has_collection(collection_name):
await self.create_collection(
@ -194,12 +194,19 @@ class LanceDBAdapter(VectorDBInterface):
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
try:
collection = await connection.open_table(collection_name)
except ValueError:
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
result_values = list(results.to_dict("index").values())
if not result_values:
return []
normalized_values = normalize_distances(result_values)
return [
@ -288,11 +295,33 @@ class LanceDBAdapter(VectorDBInterface):
if self.url.startswith("/"):
LocalStorage.remove_all(self.url)
def get_data_point_schema(self, model_type):
def get_data_point_schema(self, model_type: BaseModel):
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():
if hasattr(field_config, "model_fields"):
related_models_fields.append(field_name)
elif hasattr(field_config.annotation, "model_fields"):
related_models_fields.append(field_name)
elif (
get_origin(field_config.annotation) == Union
or get_origin(field_config.annotation) is list
):
models_list = get_args(field_config.annotation)
if any(hasattr(model, "model_fields") for model in models_list):
related_models_fields.append(field_name)
elif get_origin(field_config.annotation) == Optional:
model = get_args(field_config.annotation)
if hasattr(model, "model_fields"):
related_models_fields.append(field_name)
return copy_model(
model_type,
include_fields={
"id": (str, ...),
},
exclude_fields=["metadata"],
exclude_fields=["metadata"] + related_models_fields,
)

View file

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
@ -183,7 +184,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if collection_name in metadata.tables:
return metadata.tables[collection_name]
else:
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database
@ -244,6 +245,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
except EntityNotFoundError:
# Ignore if collection does not exist
return []
except CollectionNotFoundError:
# Ignore if collection does not exist
return []
async def search(
self,

View file

@ -1,8 +1,8 @@
import os
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator, Field
import os
from pydantic import model_validator
class LLMConfig(BaseSettings):

View file

@ -1,7 +1,6 @@
from datetime import datetime, timezone
from uuid import uuid5, NAMESPACE_DNS
from uuid import uuid5, NAMESPACE_OID
from sqlalchemy import UUID, Column, DateTime, String, Index
from sqlalchemy.orm import relationship
from cognee.infrastructure.databases.relational import Base
@ -12,7 +11,7 @@ class GraphRelationshipLedger(Base):
id = Column(
UUID,
primary_key=True,
default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
)
source_node_id = Column(UUID, nullable=False)
destination_node_id = Column(UUID, nullable=False)

View file

@ -111,9 +111,6 @@ class CogneeGraph(CogneeAbstractGraph):
except (ValueError, TypeError) as e:
print(f"Error projecting graph: {e}")
raise e
except Exception as ex:
print(f"Unexpected error: {ex}")
raise ex
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
for category, scored_results in node_distances.items():

View file

@ -2,15 +2,28 @@ from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class ChunksRetriever(BaseRetriever):
"""Retriever for handling document chunk-based searches."""
def __init__(
self,
top_k: Optional[int] = 5,
):
self.top_k = top_k
async def get_context(self, query: str) -> Any:
"""Retrieves document chunks context based on the query."""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
return [result.payload for result in found_chunks]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -1,9 +1,10 @@
from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class CompletionRetriever(BaseRetriever):
@ -20,15 +21,21 @@ class CompletionRetriever(BaseRetriever):
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
async def get_context(self, query: str) -> Any:
async def get_context(self, query: str) -> str:
"""Retrieves relevant document chunks as context."""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
if len(found_chunks) == 0:
raise NoRelevantDataFound
# Combine all chunks text returned from vector search (number of chunks is determined by top_k)
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload)
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
if len(found_chunks) == 0:
return ""
# Combine all chunks text returned from vector search (number of chunks is determined by top_k)
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates an LLM completion using the context."""

View file

@ -1,5 +1,5 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
from cognee.exceptions import CogneeApiError, CriticalError
class CollectionDistancesNotFoundError(CogneeApiError):
@ -30,3 +30,7 @@ class CypherSearchError(CogneeApiError):
status_code: int = status.HTTP_400_BAD_REQUEST,
):
super().__init__(message, name, status_code)
class NoDataError(CriticalError):
    """Raised when a search runs against a system with no ingested data."""

    # NOTE(review): this is a class attribute, not an __init__ default. Every
    # call site in this change passes an explicit message
    # (NoDataError("No data found in the system, please add data first.")),
    # so confirm the base class actually reads a class-level `message` —
    # otherwise this line is inert and an __init__ default may be intended.
    message: str = "No data found in the system, please add data first."

View file

@ -3,12 +3,12 @@ from collections import Counter
import string
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class GraphCompletionRetriever(BaseRetriever):
@ -72,14 +72,18 @@ class GraphCompletionRetriever(BaseRetriever):
query, top_k=self.top_k, collections=vector_index_collections or None
)
if len(found_triplets) == 0:
raise NoRelevantDataFound
return found_triplets
async def get_context(self, query: str) -> Any:
async def get_context(self, query: str) -> str:
"""Retrieves and resolves graph triplets into context."""
triplets = await self.get_triplets(query)
try:
triplets = await self.get_triplets(query)
except EntityNotFoundError:
return ""
if len(triplets) == 0:
return ""
return await self.resolve_edges_to_text(triplets)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -4,6 +4,8 @@ from typing import Any, Optional
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class InsightsRetriever(BaseRetriever):
@ -14,7 +16,7 @@ class InsightsRetriever(BaseRetriever):
self.exploration_levels = exploration_levels
self.top_k = top_k
async def get_context(self, query: str) -> Any:
async def get_context(self, query: str) -> list:
"""Find the neighbours of a given node in the graph."""
if query is None:
return []
@ -27,10 +29,15 @@ class InsightsRetriever(BaseRetriever):
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else:
vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
try:
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][: self.top_k]

View file

@ -2,6 +2,8 @@ from typing import Any, Optional
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class SummariesRetriever(BaseRetriever):
@ -14,7 +16,14 @@ class SummariesRetriever(BaseRetriever):
async def get_context(self, query: str) -> Any:
"""Retrieves summary context based on the query."""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit)
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.limit
)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
return [summary.payload for summary in summaries_results]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

View file

@ -82,9 +82,6 @@ async def brute_force_triplet_search(
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
retrieved_results = await brute_force_search(
query,
user,
@ -174,4 +171,4 @@ async def brute_force_search(
send_telemetry(
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
)
raise RuntimeError("An error occurred during brute force search") from error
raise error

View file

@ -20,9 +20,6 @@ async def code_description_to_code_part_search(
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
return retrieved_codeparts

View file

@ -2,9 +2,11 @@ from types import SimpleNamespace
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from cognee.modules.users.models import User
from cognee.base_config import get_base_config
from cognee.modules.users.exceptions.exceptions import UserNotFoundError
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.create_default_user import create_default_user
from cognee.base_config import get_base_config
async def get_default_user() -> SimpleNamespace:
@ -12,16 +14,24 @@ async def get_default_user() -> SimpleNamespace:
base_config = get_base_config()
default_email = base_config.default_user_email or "default_user@example.com"
async with db_engine.get_async_session() as session:
query = select(User).options(selectinload(User.roles)).where(User.email == default_email)
try:
async with db_engine.get_async_session() as session:
query = (
select(User).options(selectinload(User.roles)).where(User.email == default_email)
)
result = await session.execute(query)
user = result.scalars().first()
result = await session.execute(query)
user = result.scalars().first()
if user is None:
return await create_default_user()
if user is None:
return await create_default_user()
# We return a SimpleNamespace to have the same user type as our SaaS
# SimpleNamespace is just a dictionary which can be accessed through attributes
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
return auth_data
# We return a SimpleNamespace to have the same user type as our SaaS
# SimpleNamespace is just a dictionary which can be accessed through attributes
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
return auth_data
except Exception as error:
if "principals" in str(error.args):
raise DatabaseNotCreatedError() from error
raise UserNotFoundError(f"Failed to retrieve default user: {default_email}") from error

View file

@ -1 +1 @@
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.tasks.completion.exceptions import NoRelevantDataError

View file

@ -5,5 +5,5 @@ This module defines a set of exceptions for handling various compute errors
"""
from .exceptions import (
NoRelevantDataFound,
NoRelevantDataError,
)

View file

@ -2,11 +2,11 @@ from cognee.exceptions import CogneeApiError
from fastapi import status
class NoRelevantDataFound(CogneeApiError):
class NoRelevantDataError(CogneeApiError):
def __init__(
self,
message: str = "Search did not find any data.",
name: str = "NoRelevantDataFound",
name: str = "NoRelevantDataError",
status_code=status.HTTP_404_NOT_FOUND,
):
super().__init__(message, name, status_code)

View file

@ -1,8 +1,5 @@
import asyncio
import time
import os
from unittest.mock import patch, MagicMock
from functools import lru_cache
import asyncio
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import (
sleep_and_retry_sync,

View file

@ -1,120 +1,195 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
class TestChunksRetriever:
@pytest.fixture
def mock_retriever(self):
return ChunksRetriever()
@pytest.mark.asyncio
async def test_chunk_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
document = TextDocument(
name="Steve Rodger's career",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document,
contains=[],
)
entities = [chunk1, chunk2, chunk3]
await add_data_points(entities)
retriever = ChunksRetriever()
context = await retriever.get_context("Mike")
assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query"
doc_id1 = str(uuid.uuid4())
doc_id2 = str(uuid.uuid4())
async def test_chunk_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
)
cognee.config.data_root_directory(data_directory_path)
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"text": "This is the first chunk result.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
}
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
"text": "This is the second chunk result.",
"document_id": doc_id2,
"metadata": {"title": "Document 2"},
}
document1 = TextDocument(
name="Employee List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
document2 = TextDocument(
name="Car List",
raw_data_location="somewhere",
external_metadata="",
mime_type="text/plain",
)
# Execute
results = await mock_retriever.get_completion(query)
chunk1 = DocumentChunk(
text="Steve Rodger",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk2 = DocumentChunk(
text="Mike Broski",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
chunk3 = DocumentChunk(
text="Christina Mayer",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document1,
contains=[],
)
# Verify
assert len(results) == 2
chunk4 = DocumentChunk(
text="Range Rover",
chunk_size=2,
chunk_index=0,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk5 = DocumentChunk(
text="Hyundai",
chunk_size=2,
chunk_index=1,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
chunk6 = DocumentChunk(
text="Chrysler",
chunk_size=2,
chunk_index=2,
cut_type="sentence_end",
is_part_of=document2,
contains=[],
)
# Check first result
assert results[0]["text"] == "This is the first chunk result."
assert results[0]["document_id"] == doc_id1
assert results[0]["metadata"]["title"] == "Document 1"
entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
# Check second result
assert results[1]["text"] == "This is the second chunk result."
assert results[1]["document_id"] == doc_id2
assert results[1]["metadata"]["title"] == "Document 2"
await add_data_points(entities)
# Verify search was called correctly
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
retriever = ChunksRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
)
cognee.config.data_root_directory(data_directory_path)
# Execute
results = await mock_retriever.get_completion(query)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Verify
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
retriever = ChunksRetriever()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with incomplete data"
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"text": "This chunk has no document_id.",
# Missing document_id and metadata
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
# Missing text
"document_id": str(uuid.uuid4()),
"metadata": {"title": "Document with missing text"},
}
vector_engine = get_vector_engine()
await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
context = await retriever.get_context("Christina Mayer")
assert len(context) == 0, "Found chunks when none should exist"
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 2
if __name__ == "__main__":
from asyncio import run
# First result should have content but no document_id
assert results[0]["text"] == "This chunk has no document_id."
assert "document_id" not in results[0]
assert "metadata" not in results[0]
test = TestChunksRetriever()
# Second result should have document_id and metadata but no content
assert "text" not in results[1]
assert "document_id" in results[1]
assert results[1]["metadata"]["title"] == "Document with missing text"
run(test.test_chunk_context_simple())
run(test.test_chunk_context_complex())
run(test.test_chunk_context_on_empty_graph())

View file

@ -1,84 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestCompletionRetriever:
    """Unit tests for CompletionRetriever with the vector store and LLM mocked out."""

    @pytest.fixture
    def mock_retriever(self):
        # Retriever under test; the prompt file is never read because prompt
        # rendering and the LLM client are patched in each test.
        return CompletionRetriever(system_prompt_path="test_prompt.txt")

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
    @patch("cognee.modules.retrieval.utils.completion.render_prompt")
    @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
    async def test_get_completion(
        self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
    ):
        """A single vector hit should yield a single LLM-generated completion."""
        query = "test query"
        mock_render_prompt.return_value = "Rendered prompt with context"

        # One fake search hit coming back from the vector engine.
        hit = MagicMock()
        hit.payload = {"text": "This is a sample document chunk."}
        vector_engine_stub = AsyncMock()
        vector_engine_stub.search.return_value = [hit]
        mock_get_vector_engine.return_value = vector_engine_stub

        # LLM stub returns a canned structured-output response.
        llm_stub = MagicMock()
        llm_stub.acreate_structured_output = AsyncMock(
            return_value="Generated completion response"
        )
        mock_get_llm_client.return_value = llm_stub

        results = await mock_retriever.get_completion(query)

        assert len(results) == 1
        assert results[0] == "Generated completion response"
        # The rendered prompt must be forwarded verbatim to the LLM.
        mock_render_prompt.assert_called_once()
        llm_stub.acreate_structured_output.assert_called_once_with(
            text_input="Rendered prompt with context", system_prompt=None, response_model=str
        )

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.completion_retriever.generate_completion")
    @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
    async def test_get_completion_with_custom_prompt(
        self, mock_get_vector_engine, mock_generate_completion, mock_retriever
    ):
        """Custom user/system prompt paths must be forwarded to generate_completion."""
        query = "test query with custom prompt"

        hit = MagicMock()
        hit.payload = {"text": "This is a sample document chunk."}
        vector_engine_stub = AsyncMock()
        vector_engine_stub.search.return_value = [hit]
        mock_get_vector_engine.return_value = vector_engine_stub

        # Override both prompt paths on the retriever before completing.
        mock_retriever.user_prompt_path = "custom_user_prompt.txt"
        mock_retriever.system_prompt_path = "custom_system_prompt.txt"
        mock_generate_completion.return_value = "Custom prompt completion response"

        results = await mock_retriever.get_completion(query)

        assert len(results) == 1
        assert results[0] == "Custom prompt completion response"
        call_kwargs = mock_generate_completion.call_args[1]
        assert call_kwargs["user_prompt_path"] == "custom_user_prompt.txt"
        assert call_kwargs["system_prompt_path"] == "custom_system_prompt.txt"

View file

@ -1,236 +1,159 @@
from unittest.mock import AsyncMock, MagicMock, patch
import os
import pytest
import pathlib
from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class TestGraphCompletionRetriever:
@pytest.fixture
def mock_retriever(self):
    # Retriever under test; the prompt path is a dummy value since the tests
    # below patch prompt rendering / search rather than reading the file.
    return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever):
    """get_triplets should pass through whatever the brute-force search returns."""
    triplet = AsyncMock(
        node1=AsyncMock(attributes={"text": "Node A"}),
        attributes={"relationship_type": "connects"},
        node2=AsyncMock(attributes={"text": "Node B"}),
    )
    mock_brute_force_triplet_search.return_value = [triplet]

    triplets = await mock_retriever.get_triplets("test query")

    assert isinstance(triplets, list)
    assert len(triplets) > 0
    assert triplets[0].attributes["relationship_type"] == "connects"
    mock_brute_force_triplet_search.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
    # An empty search result surfaces as NoRelevantDataFound rather than an empty list.
    mock_brute_force_triplet_search.return_value = []
    with pytest.raises(NoRelevantDataFound):
        await mock_retriever.get_triplets("test query")
@pytest.mark.asyncio
async def test_resolve_edges_to_text(self, mock_retriever):
    """Rendering triplets should list nodes (with content markers) and connections."""
    node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
    node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
    # Node C has a "name" but no "text", so it is rendered without a content block.
    node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})
    triplets = [
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "connects"},
            node2=node_b,
        ),
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "links"},
            node2=node_c,
        ),
    ]
    # Titles for text-bearing nodes are generated; pin _get_title to a constant
    # so the string assertions below are deterministic.
    with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
        result = await mock_retriever.resolve_edges_to_text(triplets)
    assert "Nodes:" in result
    assert "Connections:" in result
    assert "Node: Test Title" in result
    assert "__node_content_start__" in result
    assert "Node A text content" in result
    assert "__node_content_end__" in result
    assert "Node: Node C" in result
    assert "Test Title --[connects]--> Test Title" in result
    assert "Test Title --[links]--> Node C" in result
@pytest.mark.asyncio
@patch(
    "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
    new_callable=AsyncMock,
)
@patch(
    "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
    new_callable=AsyncMock,
)
async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
    """Test get_context calls get_triplets and resolve_edges_to_text."""
    mock_get_triplets.return_value = ["mock_triplet"]
    mock_resolve_edges_to_text.return_value = "Mock Context"
    result = await mock_retriever.get_context("test query")
    # get_context should be a thin pipeline: query -> triplets -> rendered text.
    assert result == "Mock Context"
    mock_get_triplets.assert_called_once_with("test query")
    mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])
@pytest.mark.asyncio
@patch(
    "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_without_context(
    self, mock_generate_completion, mock_get_context, mock_retriever
):
    """Test get_completion when no context is provided (calls get_context)."""
    mock_get_context.return_value = "Mock Context"
    mock_generate_completion.return_value = "Generated Completion"
    result = await mock_retriever.get_completion("test query")
    # The completion is wrapped in a list; the context had to be fetched first.
    assert result == ["Generated Completion"]
    mock_get_context.assert_called_once_with("test query")
    mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch(
    "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
)
@patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
async def test_get_completion_with_context(
    self, mock_generate_completion, mock_get_context, mock_retriever
):
    """Test get_completion when context is provided (does not call get_context)."""
    mock_generate_completion.return_value = "Generated Completion"
    result = await mock_retriever.get_completion("test query", context="Provided Context")
    # Supplying context must skip the (expensive) graph lookup entirely.
    assert result == ["Generated Completion"]
    mock_get_context.assert_not_called()
    mock_generate_completion.assert_called_once()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.completion.get_llm_client")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
async def test_get_completion_with_empty_graph(
self,
mock_get_default_user,
mock_get_graph_engine,
mock_get_llm_client,
mock_retriever,
):
query = "test query with empty graph"
mock_graph_engine = MagicMock()
mock_graph_engine.get_graph_data = AsyncMock()
mock_graph_engine.get_graph_data.return_value = ([], [])
mock_get_graph_engine.return_value = mock_graph_engine
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = (
"Generated graph completion response"
async def test_graph_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
)
mock_get_llm_client.return_value = mock_llm_client
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
)
cognee.config.data_root_directory(data_directory_path)
with pytest.raises(EntityNotFoundError):
await mock_retriever.get_completion(query)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
mock_graph_engine.get_graph_data.assert_called_once()
class Company(DataPoint):
name: str
def test_top_n_words(self, mock_retriever):
"""Test extraction of top frequent words from text."""
text = "The quick brown fox jumps over the lazy dog. The fox is quick."
class Person(DataPoint):
name: str
works_for: Company
result = mock_retriever._top_n_words(text)
assert len(result.split(", ")) <= 3
assert "fox" in result
assert "quick" in result
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Steve Rodger", works_for=company1)
person2 = Person(name="Ike Loma", works_for=company1)
person3 = Person(name="Jason Statham", works_for=company1)
person4 = Person(name="Mike Broski", works_for=company2)
person5 = Person(name="Christina Mayer", works_for=company2)
result = mock_retriever._top_n_words(text, top_n=2)
assert len(result.split(", ")) <= 2
entities = [company1, company2, person1, person2, person3, person4, person5]
result = mock_retriever._top_n_words(text, separator=" | ")
assert " | " in result
await add_data_points(entities)
result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"})
assert "fox" not in result
assert "quick" not in result
retriever = GraphCompletionRetriever()
def test_get_title(self, mock_retriever):
"""Test title generation from text."""
text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science."
context = await retriever.get_context("Who works at Canva?")
title = mock_retriever._get_title(text)
assert "..." in title
assert "[" in title and "]" in title
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
title = mock_retriever._get_title(text, first_n_words=3)
first_part = title.split("...")[0].strip()
assert len(first_part.split()) == 3
@pytest.mark.asyncio
async def test_graph_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
)
cognee.config.data_root_directory(data_directory_path)
title = mock_retriever._get_title(text, top_n_words=2)
top_part = title.split("[")[1].split("]")[0]
assert len(top_part.split(", ")) <= 2
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
def test_get_nodes(self, mock_retriever):
"""Test node processing and deduplication."""
node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
node_without_attrs = AsyncMock(id="empty_node", attributes={})
class Company(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
edges = [
AsyncMock(
node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"}
),
AsyncMock(
node1=node_with_text,
node2=node_without_attrs,
attributes={"relationship_type": "rel2"},
),
AsyncMock(
node1=node_with_name,
node2=node_without_attrs,
attributes={"relationship_type": "rel3"},
),
class Car(DataPoint):
brand: str
model: str
year: int
class Location(DataPoint):
country: str
city: str
class Home(DataPoint):
location: Location
rooms: int
sqm: int
class Person(DataPoint):
name: str
works_for: Company
owns: Optional[list[Union[Car, Home]]] = None
company1 = Company(name="Figma")
company2 = Company(name="Canva")
person1 = Person(name="Mike Rodger", works_for=company1)
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
person2 = Person(name="Ike Loma", works_for=company1)
person2.owns = [
Car(brand="Tesla", model="Model S", year=2021),
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
]
with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
nodes = mock_retriever._get_nodes(edges)
person3 = Person(name="Jason Statham", works_for=company1)
assert len(nodes) == 3
person4 = Person(name="Mike Broski", works_for=company2)
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
for node_id, info in nodes.items():
assert "node" in info
assert "name" in info
assert "content" in info
person5 = Person(name="Christina Mayer", works_for=company2)
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
text_node_info = nodes[node_with_text.id]
assert text_node_info["name"] == "Generated Title"
assert text_node_info["content"] == "This is a text node"
entities = [company1, company2, person1, person2, person3, person4, person5]
name_node_info = nodes[node_with_name.id]
assert name_node_info["name"] == "Named Node"
assert name_node_info["content"] == "Named Node"
await add_data_points(entities)
empty_node_info = nodes[node_without_attrs.id]
assert empty_node_info["name"] == "Unnamed Node"
retriever = GraphCompletionRetriever(top_k=20)
context = await retriever.get_context("Who works at Figma?")
print(context)
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(self):
    """Querying before setup raises; after setup an empty graph yields empty context."""
    # NOTE(review): these directories are shared with the complex graph-completion
    # test (".../test_graph_completion_context") — consider test-specific names
    # to guarantee isolation between tests.
    system_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
    )
    cognee.config.system_root_directory(system_directory_path)
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
    )
    cognee.config.data_root_directory(data_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    retriever = GraphCompletionRetriever()
    # The databases do not exist yet (setup() not called), so retrieval must fail loudly.
    with pytest.raises(DatabaseNotCreatedError):
        await retriever.get_context("Who works at Figma?")
    await setup()
    context = await retriever.get_context("Who works at Figma?")
    assert context == "", "Context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TestGraphCompletionRetriever()
run(test.test_graph_completion_context_simple())
run(test.test_graph_completion_context_complex())
run(test.test_get_graph_completion_context_on_empty_graph())

View file

@ -1,80 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
class TestGraphSummaryCompletionRetriever:
    """Unit tests for GraphSummaryCompletionRetriever with all external services mocked."""

    @pytest.fixture
    def mock_retriever(self):
        # Retriever under test; the prompt file is never read because the
        # prompt helpers are patched in each test.
        return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
    @patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
    @patch("cognee.modules.retrieval.utils.completion.render_prompt")
    @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
    async def test_get_completion_with_custom_system_prompt(
        self,
        mock_get_default_user,
        mock_render_prompt,
        mock_read_query_prompt,
        mock_get_llm_client,
        mock_retriever,
    ):
        """Custom user/system prompt paths must reach the prompt helpers."""
        # Setup
        query = "test query with custom prompt"
        # Set custom system prompt
        mock_retriever.user_prompt_path = "custom_user_prompt.txt"
        mock_retriever.system_prompt_path = "custom_system_prompt.txt"
        mock_llm_client = MagicMock()
        mock_llm_client.acreate_structured_output = AsyncMock()
        mock_llm_client.acreate_structured_output.return_value = (
            "Generated graph summary completion response"
        )
        mock_get_llm_client.return_value = mock_llm_client
        # Execute — context is supplied directly, so no graph search happens here.
        results = await mock_retriever.get_completion(query, context="test context")
        # Verify
        assert len(results) == 1
        # Verify render_prompt was called with custom prompt path
        mock_render_prompt.assert_called_once()
        assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"
        mock_read_query_prompt.assert_called_once()
        assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"
        mock_llm_client.acreate_structured_output.assert_called_once()

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
    )
    @patch(
        "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
        new_callable=AsyncMock,
    )
    async def test_resolve_edges_to_text_calls_super_and_summarizes(
        self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
    ):
        """Test resolve_edges_to_text calls the parent method and summarizes the result."""
        mock_resolve_edges_to_text.return_value = "Raw graph edges text"
        mock_summarize_text.return_value = "Summarized graph text"
        result = await mock_retriever.resolve_edges_to_text(["mock_edge"])
        # The parent renders the raw edge text, which is then summarized with
        # the retriever's own summarize prompt.
        mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
        mock_summarize_text.assert_called_once_with(
            "Raw graph edges text", mock_retriever.summarize_prompt_path
        )
        assert result == "Summarized graph text"

View file

@ -1,103 +1,216 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import os
import pytest
import pathlib
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
import unittest
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models import Entity, EntityType
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
class TestInsightsRetriever:
@pytest.fixture
def mock_retriever(self):
    # Default InsightsRetriever instance shared by the unit tests below.
    return InsightsRetriever()
@pytest.mark.asyncio
async def test_insights_context_simple(self):
    """Insights retrieval over a small people/companies graph finds the right entity."""
    # Per-test storage directories keep this test isolated from others.
    system_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple"
    )
    cognee.config.system_root_directory(system_directory_path)
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple"
    )
    cognee.config.data_root_directory(data_directory_path)

    # Reset all stored data/metadata before building the fixture graph.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    person_type = EntityType(
        name="Person",
        description="An individual",
    )
    people = [
        Entity(
            name="Steve Rodger",
            is_a=person_type,
            description="An American actor, comedian, and filmmaker",
        ),
        Entity(
            name="Mike Broski",
            is_a=person_type,
            description="Financial advisor and philanthropist",
        ),
        Entity(
            name="Christina Mayer",
            is_a=person_type,
            description="Maker of next generation of iconic American music videos",
        ),
    ]

    company_type = EntityType(
        name="Company",
        description="An organization that operates on an annual basis",
    )
    companies = [
        Entity(
            name="Apple",
            is_a=company_type,
            description="An American multinational technology company headquartered in Cupertino, California",
        ),
        Entity(
            name="Google",
            is_a=company_type,
            description="An American multinational technology company that specializes in Internet-related services and products",
        ),
        Entity(
            name="Facebook",
            is_a=company_type,
            description="An American social media, messaging, and online platform",
        ),
    ]

    # Persist people first, then companies — same order as the original fixture.
    await add_data_points(people + companies)

    retriever = InsightsRetriever()
    context = await retriever.get_context("Mike")
    assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever):
"""Test get_context when node exists in graph."""
mock_graph = AsyncMock()
mock_get_graph_engine.return_value = mock_graph
async def test_insights_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
# Mock graph response
mock_graph.extract_node.return_value = {"id": "123"}
mock_graph.get_connections.return_value = [
({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
]
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
result = await mock_retriever.get_context("123")
entityTypePerson = EntityType(
name="Person",
description="An individual",
)
assert isinstance(result, list)
assert len(result) == 1
assert result[0][0]["id"] == "123"
assert result[0][1]["relationship_name"] == "linked_to"
assert result[0][2]["id"] == "456"
mock_graph.extract_node.assert_called_once_with("123")
mock_graph.get_connections.assert_called_once_with("123")
person1 = Entity(
name="Steve Rodger",
is_a=entityTypePerson,
description="An American actor, comedian, and filmmaker",
)
person2 = Entity(
name="Mike Broski",
is_a=entityTypePerson,
description="Financial advisor and philanthropist",
)
person3 = Entity(
name="Christina Mayer",
is_a=entityTypePerson,
description="Maker of next generation of iconic American music videos",
)
person4 = Entity(
name="Jason Statham",
is_a=entityTypePerson,
description="An American actor",
)
person5 = Entity(
name="Mike Tyson",
is_a=entityTypePerson,
description="A former professional boxer from the United States",
)
entityTypeCompany = EntityType(
name="Company",
description="An organization that operates on an annual basis",
)
company1 = Entity(
name="Apple",
is_a=entityTypeCompany,
description="An American multinational technology company headquartered in Cupertino, California",
)
company2 = Entity(
name="Google",
is_a=entityTypeCompany,
description="An American multinational technology company that specializes in Internet-related services and products",
)
company3 = Entity(
name="Facebook",
is_a=entityTypeCompany,
description="An American social media, messaging, and online platform",
)
entities = [person1, person2, person3, company1, company2, company3]
await add_data_points(entities)
graph_engine = await get_graph_engine()
await graph_engine.add_edges(
[
(person1.id, company1.id, "works_for"),
(person2.id, company2.id, "works_for"),
(person3.id, company3.id, "works_for"),
(person4.id, company1.id, "works_for"),
(person5.id, company1.id, "works_for"),
]
)
retriever = InsightsRetriever(top_k=20)
context = await retriever.get_context("Christina")
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
async def test_insights_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
)
cognee.config.data_root_directory(data_directory_path)
# Execute
results = await mock_retriever.get_completion(query)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Verify
assert len(results) == 0
retriever = InsightsRetriever()
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
@patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
async def test_get_context_with_no_exact_node(
self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
):
"""Test get_context when node does not exist in the graph and vector search is used."""
mock_graph = AsyncMock()
mock_get_graph_engine.return_value = mock_graph
mock_graph.extract_node.return_value = None # Node does not exist
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
mock_vector = AsyncMock()
mock_get_vector_engine.return_value = mock_vector
vector_engine = get_vector_engine()
await vector_engine.create_collection("Entity_name", payload_schema=Entity)
await vector_engine.create_collection("EntityType_name", payload_schema=EntityType)
mock_vector.search.side_effect = [
[AsyncMock(id="vec_1", score=0.4)], # Entity_name search
[AsyncMock(id="vec_2", score=0.3)], # EntityType_name search
]
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"
mock_graph.get_connections.side_effect = lambda node_id: [
({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
]
result = await mock_retriever.get_context("non_existing_query")
if __name__ == "__main__":
from asyncio import run
assert isinstance(result, list)
assert len(result) == 2
assert result[0][0]["id"] == "vec_1"
assert result[0][1]["relationship_name"] == "related_to"
assert result[0][2]["id"] == "456"
test = TestInsightsRetriever()
assert result[1][0]["id"] == "vec_2"
assert result[1][1]["relationship_name"] == "related_to"
assert result[1][2]["id"] == "456"
@pytest.mark.asyncio
async def test_get_context_with_none_query(self, mock_retriever):
    """A None query must yield an empty context rather than raising."""
    assert await mock_retriever.get_context(None) == []
@pytest.mark.asyncio
async def test_get_completion_with_context(self, mock_retriever):
    """When a context is supplied, get_completion must pass it through untouched."""
    provided_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
    completion = await mock_retriever.get_completion("test_query", context=provided_context)
    assert completion == provided_context
run(test.test_insights_context_simple())
run(test.test_insights_context_complex())
run(test.test_insights_context_on_empty_graph())

View file

@ -0,0 +1,196 @@
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
class TestRAGCompletionRetriever:
    """Integration tests for CompletionRetriever's vector-store backed context lookup.

    Each test isolates its on-disk state under a test-specific system/data
    directory and prunes all previous data so repeated runs are deterministic.
    """

    @staticmethod
    def _build_chunks(document, texts):
        """Build sequential DocumentChunks of `document` (chunk_index 0..n-1), one per text."""
        return [
            DocumentChunk(
                text=text,
                chunk_size=2,
                chunk_index=chunk_index,
                cut_type="sentence_end",
                is_part_of=document,
                contains=[],
            )
            for chunk_index, text in enumerate(texts)
        ]

    @pytest.mark.asyncio
    async def test_rag_completion_context_simple(self):
        """The retriever returns the text of the single best-matching chunk."""
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        # Clean slate: remove any data/system state left over from earlier runs.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        entities = self._build_chunks(
            document, ["Steve Rodger", "Mike Broski", "Christina Mayer"]
        )
        await add_data_points(entities)

        retriever = CompletionRetriever()
        context = await retriever.get_context("Mike")
        assert context == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_rag_completion_context_complex(self):
        """With chunks from two documents, the best match is still returned first."""
        # Fixed: this test previously reused ".cognee_system/test_graph_completion_context"
        # (copy-pasted from the graph-completion test suite); a RAG-specific
        # directory keeps the suites from sharing on-disk state.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )
        entities = self._build_chunks(
            document1, ["Steve Rodger", "Mike Broski", "Christina Mayer"]
        ) + self._build_chunks(document2, ["Range Rover", "Hyundai", "Chrysler"])
        await add_data_points(entities)

        # TODO: top_k doesn't affect the output, it should be fixed.
        retriever = CompletionRetriever(top_k=20)
        context = await retriever.get_context("Christina")
        assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_get_rag_completion_context_on_empty_graph(self):
        """Without data: NoDataError before the collection exists, empty context after."""
        # Fixed: use a dedicated directory (previously shared with the complex
        # test above, which could leak state between tests and break determinism).
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_empty"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_empty"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = CompletionRetriever()
        # The backing collection does not exist yet, so the lookup must fail loudly.
        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)

        context = await retriever.get_context("Christina Mayer")
        assert context == "", "Returned context should be empty on an empty graph"
if __name__ == "__main__":
    # Allow running this suite directly, outside of pytest.
    from asyncio import run

    suite = TestRAGCompletionRetriever()
    for test_case in (
        suite.test_rag_completion_context_simple,
        suite.test_rag_completion_context_complex,
        suite.test_get_rag_completion_context_on_empty_graph,
    ):
        run(test_case())

View file

@ -1,122 +1,168 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import os
import pytest
import pathlib
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
class TestSummariesRetriever:
@pytest.fixture
def mock_retriever(self):
return SummariesRetriever()
class TextSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
    """SummariesRetriever should surface the summary of the best-matching chunk."""
    system_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
    )
    cognee.config.system_root_directory(system_directory_path)
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
    )
    cognee.config.data_root_directory(data_directory_path)

    # Reset all persisted state so the run is deterministic.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    await setup()

    employee_document = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )
    car_document = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # (document, chunk index within that document, chunk text, summary text)
    chunk_specs = [
        (employee_document, 0, "Steve Rodger", "S.R."),
        (employee_document, 1, "Mike Broski", "M.B."),
        (employee_document, 2, "Christina Mayer", "C.M."),
        (car_document, 0, "Range Rover", "R.R."),
        (car_document, 1, "Hyundai", "H.Y."),
        (car_document, 2, "Chrysler", "C.H."),
    ]
    entities = [
        TextSummary(
            text=summary_text,
            made_from=DocumentChunk(
                text=chunk_text,
                chunk_size=2,
                chunk_index=chunk_index,
                cut_type="sentence_end",
                is_part_of=document,
                contains=[],
            ),
        )
        for document, chunk_index, chunk_text, summary_text in chunk_specs
    ]
    await add_data_points(entities)

    retriever = SummariesRetriever(limit=20)
    context = await retriever.get_context("Christina")
    assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query"
doc_id1 = str(uuid.uuid4())
doc_id2 = str(uuid.uuid4())
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
)
cognee.config.data_root_directory(data_directory_path)
# Mock search results
mock_result_1 = MagicMock()
mock_result_1.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is the first summary.",
"document_id": doc_id1,
"metadata": {"title": "Document 1"},
},
}
mock_result_2 = MagicMock()
mock_result_2.payload = {
"id": str(uuid.uuid4()),
"score": 0.85,
"payload": {
"text": "This is the second summary.",
"document_id": doc_id2,
"metadata": {"title": "Document 2"},
},
}
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
mock_search_results = [mock_result_1, mock_result_2]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
retriever = SummariesRetriever()
# Execute
results = await mock_retriever.get_completion(query)
with pytest.raises(NoDataError):
await retriever.get_context("Christina Mayer")
# Verify
assert len(results) == 2
vector_engine = get_vector_engine()
await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)
# Check first result
assert results[0]["payload"]["text"] == "This is the first summary."
assert results[0]["payload"]["document_id"] == doc_id1
assert results[0]["payload"]["metadata"]["title"] == "Document 1"
assert results[0]["score"] == 0.95
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"
# Check second result
assert results[1]["payload"]["text"] == "This is the second summary."
assert results[1]["payload"]["document_id"] == doc_id2
assert results[1]["payload"]["metadata"]["title"] == "Document 2"
assert results[1]["score"] == 0.85
# Verify search was called correctly
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
if __name__ == "__main__":
from asyncio import run
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with no results"
mock_search_results = []
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
test = TextSummariesRetriever()
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 0
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
# Setup
query = "test query with custom limit"
doc_id = str(uuid.uuid4())
# Mock search results
mock_result = MagicMock()
mock_result.payload = {
"id": str(uuid.uuid4()),
"score": 0.95,
"payload": {
"text": "This is a summary.",
"document_id": doc_id,
"metadata": {"title": "Document 1"},
},
}
mock_search_results = [mock_result]
mock_vector_engine = AsyncMock()
mock_vector_engine.search.return_value = mock_search_results
mock_get_vector_engine.return_value = mock_vector_engine
# Set custom limit
mock_retriever.limit = 10
# Execute
results = await mock_retriever.get_completion(query)
# Verify
assert len(results) == 1
assert results[0]["payload"]["text"] == "This is a summary."
# Verify search was called with custom limit
mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)
run(test.test_chunk_context())
run(test.test_chunk_context_on_empty_graph())

View file

@ -1,11 +1,11 @@
import pytest
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from unittest.mock import AsyncMock, patch
from cognee.modules.users.models import User
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_search,
brute_force_triplet_search,
)
from unittest.mock import AsyncMock, patch
@pytest.mark.asyncio
@ -20,13 +20,11 @@ async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(Exception) as exc_info:
with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
)
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
@ -40,9 +38,7 @@ async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_e
mock_vector_engine.get_distance_from_collection_elements.return_value = []
mock_get_vector_engine.return_value = mock_vector_engine
with pytest.raises(Exception) as exc_info:
with pytest.raises(CollectionDistancesNotFoundError):
await brute_force_triplet_search(
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
)
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)