test: make search unit tests deterministic (#726)
## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
parent 751eca7aaf
commit 675b66175f

40 changed files with 1127 additions and 824 deletions
62  .github/workflows/basic_tests.yml  (vendored)
@@ -8,12 +8,30 @@ on:
        type: string
        default: '3.11.x'
    secrets:
      OPENAI_API_KEY:
        required: true
      GRAPHISTRY_USERNAME:
        required: true
      GRAPHISTRY_PASSWORD:
        required: true
      LLM_PROVIDER:
        required: true
      LLM_MODEL:
        required: true
      LLM_ENDPOINT:
        required: true
      LLM_API_KEY:
        required: true
      LLM_API_VERSION:
        required: true
      EMBEDDING_PROVIDER:
        required: true
      EMBEDDING_MODEL:
        required: true
      EMBEDDING_ENDPOINT:
        required: true
      EMBEDDING_API_KEY:
        required: true
      EMBEDDING_API_VERSION:
        required: true

env:
  RUNTIME__LOG_LEVEL: ERROR
@@ -60,6 +78,18 @@ jobs:
  unit-tests:
    name: Run Unit Tests
    runs-on: ubuntu-22.04
+    env:
+      LLM_PROVIDER: openai
+      LLM_MODEL: ${{ secrets.LLM_MODEL }}
+      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+
+      EMBEDDING_PROVIDER: openai
+      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
@@ -95,10 +125,20 @@ jobs:
    name: Run Simple Examples
    runs-on: ubuntu-22.04
    env:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-      LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
+
+      LLM_PROVIDER: openai
+      LLM_MODEL: ${{ secrets.LLM_MODEL }}
+      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+
+      EMBEDDING_PROVIDER: openai
+      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
@@ -117,10 +157,20 @@ jobs:
    name: Run Basic Graph Tests
    runs-on: ubuntu-22.04
    env:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-      LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
+
+      LLM_PROVIDER: openai
+      LLM_MODEL: ${{ secrets.LLM_MODEL }}
+      LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+      LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+      LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+
+      EMBEDDING_PROVIDER: openai
+      EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+      EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+      EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+      EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
2  .github/workflows/graph_db_tests.yml  (vendored)
@@ -1,4 +1,4 @@
-name: Reusable Vector DB Tests
+name: Reusable Graph DB Tests

on:
  workflow_call:
41  .github/workflows/python_version_tests.yml  (vendored)
@@ -8,26 +8,30 @@ on:
        type: string
        default: '["3.10.x", "3.11.x", "3.12.x"]'
    secrets:
      OPENAI_API_KEY:
        required: true
      GRAPHISTRY_USERNAME:
        required: true
      GRAPHISTRY_PASSWORD:
        required: true
      LLM_PROVIDER:
        required: true
      LLM_MODEL:
-        required: false
+        required: true
      LLM_ENDPOINT:
-        required: false
+        required: true
      LLM_API_KEY:
        required: true
      LLM_API_VERSION:
-        required: false
+        required: true
      EMBEDDING_PROVIDER:
        required: true
      EMBEDDING_MODEL:
-        required: false
+        required: true
      EMBEDDING_ENDPOINT:
-        required: false
+        required: true
      EMBEDDING_API_KEY:
-        required: false
+        required: true
      EMBEDDING_API_VERSION:
-        required: false
+        required: true

env:
  RUNTIME__LOG_LEVEL: ERROR
@@ -55,6 +59,18 @@ jobs:

      - name: Run unit tests
        run: poetry run pytest cognee/tests/unit/
+        env:
+          LLM_PROVIDER: openai
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+
+          EMBEDDING_PROVIDER: openai
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}

      - name: Run integration tests
        if: ${{ !contains(matrix.os, 'windows') }}
@@ -62,13 +78,16 @@ jobs:

      - name: Run default basic pipeline
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}

          LLM_PROVIDER: openai
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

          EMBEDDING_PROVIDER: openai
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@@ -8,6 +8,8 @@ on:
        type: string
        default: '3.11.x'
    secrets:
      LLM_PROVIDER:
        required: true
      LLM_MODEL:
        required: true
      LLM_ENDPOINT:
@@ -16,6 +18,8 @@ on:
        required: true
      LLM_API_VERSION:
        required: true
      EMBEDDING_PROVIDER:
        required: true
      EMBEDDING_MODEL:
        required: true
      EMBEDDING_ENDPOINT:
@@ -24,12 +28,6 @@ on:
        required: true
      EMBEDDING_API_VERSION:
        required: true
-      OPENAI_API_KEY:
-        required: true
-      GRAPHISTRY_USERNAME:
-        required: true
-      GRAPHISTRY_PASSWORD:
-        required: true

jobs:
  run-relational-db-migration-test-networkx:
@@ -81,10 +79,13 @@
      - name: Run relational db test
        env:
          ENV: 'dev'
          LLM_PROVIDER: openai
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

          EMBEDDING_PROVIDER: openai
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@@ -141,10 +142,14 @@
        env:
          ENV: 'dev'
          GRAPH_DATABASE_PROVIDER: 'kuzu'

          LLM_PROVIDER: openai
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

          EMBEDDING_PROVIDER: openai
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@@ -204,10 +209,14 @@
          GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
          GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
          GRAPH_DATABASE_USERNAME: "neo4j"

          LLM_PROVIDER: openai
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}

          EMBEDDING_PROVIDER: openai
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
@@ -1,28 +0,0 @@
-"""Add default user
-
-Revision ID: 482cd6517ce4
-Revises: 8057ae7329c2
-Create Date: 2024-10-16 22:17:18.634638
-
-"""
-
-from typing import Sequence, Union
-
-from sqlalchemy.util import await_only
-
-from cognee.modules.users.methods import create_default_user, delete_user
-
-
-# revision identifiers, used by Alembic.
-revision: str = "482cd6517ce4"
-down_revision: Union[str, None] = "8057ae7329c2"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
-
-
-def upgrade() -> None:
-    await_only(create_default_user())
-
-
-def downgrade() -> None:
-    await_only(delete_user("default_user@example.com"))
@@ -1,27 +0,0 @@
-"""Initial migration
-
-Revision ID: 8057ae7329c2
-Revises:
-Create Date: 2024-10-02 12:55:20.989372
-
-"""
-
-from typing import Sequence, Union
-from sqlalchemy.util import await_only
-from cognee.infrastructure.databases.relational import get_relational_engine
-
-# revision identifiers, used by Alembic.
-revision: str = "8057ae7329c2"
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    db_engine = get_relational_engine()
-    await_only(db_engine.create_database())
-
-
-def downgrade() -> None:
-    db_engine = get_relational_engine()
-    await_only(db_engine.delete_database())
@@ -1,8 +1,7 @@
from typing import Union

from cognee.modules.search.types import SearchType
from cognee.modules.users.exceptions import UserNotFoundError
from cognee.modules.users.models import User
-from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function

@@ -22,9 +21,6 @@ async def search(
    if user is None:
        user = await get_default_user()

-    if user is None:
-        raise UserNotFoundError
-
    filtered_search_results = await search_function(
        query_text,
        query_type,
@@ -10,4 +10,5 @@ from .exceptions import (
    ServiceError,
    InvalidValueError,
    InvalidAttributeError,
+    CriticalError,
)
@@ -53,3 +53,7 @@ class InvalidAttributeError(CogneeApiError):
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        super().__init__(message, name, status_code)
+
+
+class CriticalError(CogneeApiError):
+    pass
@@ -7,4 +7,5 @@ This module defines a set of exceptions for handling various database errors
from .exceptions import (
    EntityNotFoundError,
    EntityAlreadyExistsError,
+    DatabaseNotCreatedError,
)
@@ -1,5 +1,15 @@
-from cognee.exceptions import CogneeApiError
+from fastapi import status
+from cognee.exceptions import CogneeApiError, CriticalError


+class DatabaseNotCreatedError(CriticalError):
+    def __init__(
+        self,
+        message: str = "The database has not been created yet. Please call `await setup()` first.",
+        name: str = "DatabaseNotCreatedError",
+        status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
+    ):
+        super().__init__(message, name, status_code)
+
+
class EntityNotFoundError(CogneeApiError):
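For context, a minimal sketch of the call order the new error message points users toward. The `setup()` and `get_default_user()` entry points both appear elsewhere in this diff; the wrapper script itself is illustrative, not part of the commit:

import asyncio

from cognee.low_level import setup
from cognee.modules.users.methods import get_default_user


async def main():
    # On a fresh system the relational database does not exist yet.
    # setup() creates it, so the call below no longer fails with
    # DatabaseNotCreatedError.
    await setup()
    user = await get_default_user()
    print(user.id)


asyncio.run(main())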
@@ -1,13 +1,13 @@
-from typing import Protocol, Optional, Dict, Any, List, Tuple
-from abc import abstractmethod, ABC
-from uuid import UUID, uuid5, NAMESPACE_DNS
-from cognee.modules.graph.relationship_manager import create_relationship
-from functools import wraps
import inspect
+from functools import wraps
+from abc import abstractmethod, ABC
+from datetime import datetime, timezone
+from typing import Optional, Dict, Any, List, Tuple
+from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
-from cognee.shared.logging_utils import get_logger
-from datetime import datetime, timezone

logger = get_logger()

@@ -44,20 +44,16 @@ def record_graph_changes(func):

        async with db_engine.get_async_session() as session:
            if func.__name__ == "add_nodes":
-                nodes = args[0]
+                nodes: List[DataPoint] = args[0]
                for node in nodes:
                    try:
-                        node_id = (
-                            UUID(str(node[0])) if isinstance(node, tuple) else UUID(str(node.id))
-                        )
+                        node_id = UUID(str(node.id))
                        relationship = GraphRelationshipLedger(
-                            id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
+                            id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                            source_node_id=node_id,
                            destination_node_id=node_id,
                            creator_function=f"{creator}.node",
-                            node_label=node[1].get("type")
-                            if isinstance(node, tuple)
-                            else type(node).__name__,
+                            node_label=getattr(node, "name", None) or str(node.id),
                        )
                        session.add(relationship)
                        await session.flush()

@@ -74,7 +70,7 @@ def record_graph_changes(func):
                    target_id = UUID(str(edge[1]))
                    rel_type = str(edge[2])
                    relationship = GraphRelationshipLedger(
-                        id=uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
+                        id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                        source_node_id=source_id,
                        destination_node_id=target_id,
                        creator_function=f"{creator}.{rel_type}",
@@ -0,0 +1 @@
+from .exceptions import CollectionNotFoundError
@@ -0,0 +1,12 @@
+from fastapi import status
+from cognee.exceptions import CriticalError
+
+
+class CollectionNotFoundError(CriticalError):
+    def __init__(
+        self,
+        message,
+        name: str = "DatabaseNotCreatedError",
+        status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
+    ):
+        super().__init__(message, name, status_code)
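A minimal self-contained sketch of the error-handling contract these new files set up: vector adapters raise the infrastructure-level CollectionNotFoundError, and retrievers translate it into a user-facing error. The simplified classes and functions below are stand-ins for illustration, not the real adapters (which follow later in this diff):

import asyncio


class CriticalError(Exception):
    pass


class CollectionNotFoundError(CriticalError):
    pass


class NoDataError(CriticalError):
    pass


async def vector_search(collection_name: str, query: str) -> list:
    tables: dict = {}  # stand-in for the vector store's tables
    if collection_name not in tables:
        # the adapter raises the infrastructure-level error
        raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
    return tables[collection_name]


async def get_context(query: str) -> list:
    try:
        return await vector_search("DocumentChunk_text", query)
    except CollectionNotFoundError as error:
        # the retriever translates it into the user-facing error
        raise NoDataError("No data found in the system, please add data first.") from error


# asyncio.run(get_context("example")) raises NoDataError on an empty system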
@@ -1,5 +1,5 @@
import asyncio
-from typing import Generic, List, Optional, TypeVar, get_type_hints
+from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

import lancedb
from lancedb.pydantic import LanceModel, Vector

@@ -10,6 +10,7 @@ from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult

@@ -79,7 +80,6 @@ class LanceDBAdapter(VectorDBInterface):
        connection = await self.get_connection()

        payload_schema = type(data_points[0])
        payload_schema = self.get_data_point_schema(payload_schema)

        if not await self.has_collection(collection_name):
            await self.create_collection(

@@ -194,12 +194,19 @@ class LanceDBAdapter(VectorDBInterface):
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        connection = await self.get_connection()
-        collection = await connection.open_table(collection_name)
+        try:
+            collection = await connection.open_table(collection_name)
+        except ValueError:
+            raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

        results = await collection.vector_search(query_vector).limit(limit).to_pandas()

        result_values = list(results.to_dict("index").values())

+        if not result_values:
+            return []
+
        normalized_values = normalize_distances(result_values)

        return [

@@ -288,11 +295,33 @@ class LanceDBAdapter(VectorDBInterface):
        if self.url.startswith("/"):
            LocalStorage.remove_all(self.url)

-    def get_data_point_schema(self, model_type):
+    def get_data_point_schema(self, model_type: BaseModel):
+        related_models_fields = []
+
+        for field_name, field_config in model_type.model_fields.items():
+            if hasattr(field_config, "model_fields"):
+                related_models_fields.append(field_name)
+
+            elif hasattr(field_config.annotation, "model_fields"):
+                related_models_fields.append(field_name)
+
+            elif (
+                get_origin(field_config.annotation) == Union
+                or get_origin(field_config.annotation) is list
+            ):
+                models_list = get_args(field_config.annotation)
+                if any(hasattr(model, "model_fields") for model in models_list):
+                    related_models_fields.append(field_name)
+
+            elif get_origin(field_config.annotation) == Optional:
+                model = get_args(field_config.annotation)
+                if hasattr(model, "model_fields"):
+                    related_models_fields.append(field_name)
+
        return copy_model(
            model_type,
            include_fields={
                "id": (str, ...),
            },
-            exclude_fields=["metadata"],
+            exclude_fields=["metadata"] + related_models_fields,
        )
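The get_data_point_schema change above relies on typing.get_origin/get_args to spot fields whose annotation wraps another pydantic model (so those fields can be excluded from the LanceDB schema). A standalone sketch of that detection logic, using hypothetical Inner/Outer models purely for illustration:

from typing import List, Optional, Union, get_args, get_origin

from pydantic import BaseModel


class Inner(BaseModel):
    value: str


class Outer(BaseModel):
    id: str
    child: Optional[Inner] = None  # Optional[X] is Union[X, None]
    children: List[Inner] = []


def related_model_fields(model_type) -> list:
    related = []
    for name, field in model_type.model_fields.items():
        annotation = field.annotation
        if hasattr(annotation, "model_fields"):  # a nested model, directly
            related.append(name)
        elif get_origin(annotation) in (Union, list):  # Union/Optional/list wrappers
            if any(hasattr(arg, "model_fields") for arg in get_args(annotation)):
                related.append(name)
    return related


print(related_model_fields(Outer))  # ['child', 'children']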
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine

@@ -183,7 +184,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
        if collection_name in metadata.tables:
            return metadata.tables[collection_name]
        else:
-            raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
+            raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
        # Get PGVectorDataPoint Table from database

@@ -244,6 +245,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
        except EntityNotFoundError:
            # Ignore if collection does not exist
            return []
+        except CollectionNotFoundError:
+            # Ignore if collection does not exist
+            return []

    async def search(
        self,
@@ -1,8 +1,8 @@
+import os
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import model_validator, Field
-import os
-from pydantic import model_validator


class LLMConfig(BaseSettings):
@@ -1,7 +1,6 @@
from datetime import datetime, timezone
-from uuid import uuid5, NAMESPACE_DNS
+from uuid import uuid5, NAMESPACE_OID
from sqlalchemy import UUID, Column, DateTime, String, Index
-from sqlalchemy.orm import relationship

from cognee.infrastructure.databases.relational import Base

@@ -12,7 +11,7 @@ class GraphRelationshipLedger(Base):
    id = Column(
        UUID,
        primary_key=True,
-        default=lambda: uuid5(NAMESPACE_DNS, f"{datetime.now(timezone.utc).timestamp()}"),
+        default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
    )
    source_node_id = Column(UUID, nullable=False)
    destination_node_id = Column(UUID, nullable=False)
@@ -111,9 +111,6 @@ class CogneeGraph(CogneeAbstractGraph):
        except (ValueError, TypeError) as e:
            print(f"Error projecting graph: {e}")
            raise e
-        except Exception as ex:
-            print(f"Unexpected error: {ex}")
-            raise ex

    async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
        for category, scored_results in node_distances.items():
@@ -2,15 +2,28 @@ from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError


class ChunksRetriever(BaseRetriever):
    """Retriever for handling document chunk-based searches."""

+    def __init__(
+        self,
+        top_k: Optional[int] = 5,
+    ):
+        self.top_k = top_k
+
    async def get_context(self, query: str) -> Any:
        """Retrieves document chunks context based on the query."""
        vector_engine = get_vector_engine()
-        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
+
+        try:
+            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+        except CollectionNotFoundError as error:
+            raise NoDataError("No data found in the system, please add data first.") from error

        return [result.payload for result in found_chunks]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -1,9 +1,10 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
-from cognee.tasks.completion.exceptions import NoRelevantDataFound
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError


class CompletionRetriever(BaseRetriever):

@@ -20,15 +21,21 @@ class CompletionRetriever(BaseRetriever):
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 1

-    async def get_context(self, query: str) -> Any:
+    async def get_context(self, query: str) -> str:
        """Retrieves relevant document chunks as context."""
        vector_engine = get_vector_engine()
-        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
-        if len(found_chunks) == 0:
-            raise NoRelevantDataFound
-        # Combine all chunks text returned from vector search (number of chunks is determined by top_k
-        chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-        return "\n".join(chunks_payload)
+
+        try:
+            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+
+            if len(found_chunks) == 0:
+                return ""
+
+            # Combine all chunks text returned from vector search (number of chunks is determined by top_k)
+            chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
+            return "\n".join(chunks_payload)
+        except CollectionNotFoundError as error:
+            raise NoDataError("No data found in the system, please add data first.") from error

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates an LLM completion using the context."""
@@ -1,5 +1,5 @@
-from cognee.exceptions import CogneeApiError
from fastapi import status
+from cognee.exceptions import CogneeApiError, CriticalError


class CollectionDistancesNotFoundError(CogneeApiError):

@@ -30,3 +30,7 @@ class CypherSearchError(CogneeApiError):
        status_code: int = status.HTTP_400_BAD_REQUEST,
    ):
        super().__init__(message, name, status_code)
+
+
+class NoDataError(CriticalError):
+    message: str = "No data found in the system, please add data first."
@@ -3,12 +3,12 @@ from collections import Counter
import string

from cognee.infrastructure.engine import DataPoint
+from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
-from cognee.tasks.completion.exceptions import NoRelevantDataFound


class GraphCompletionRetriever(BaseRetriever):

@@ -72,14 +72,18 @@ class GraphCompletionRetriever(BaseRetriever):
            query, top_k=self.top_k, collections=vector_index_collections or None
        )

-        if len(found_triplets) == 0:
-            raise NoRelevantDataFound
-
        return found_triplets

-    async def get_context(self, query: str) -> Any:
+    async def get_context(self, query: str) -> str:
        """Retrieves and resolves graph triplets into context."""
-        triplets = await self.get_triplets(query)
+        try:
+            triplets = await self.get_triplets(query)
+        except EntityNotFoundError:
+            return ""
+
+        if len(triplets) == 0:
+            return ""

        return await self.resolve_edges_to_text(triplets)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -4,6 +4,8 @@ from typing import Any, Optional
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError


class InsightsRetriever(BaseRetriever):

@@ -14,7 +16,7 @@ class InsightsRetriever(BaseRetriever):
        self.exploration_levels = exploration_levels
        self.top_k = top_k

-    async def get_context(self, query: str) -> Any:
+    async def get_context(self, query: str) -> list:
        """Find the neighbours of a given node in the graph."""
        if query is None:
            return []

@@ -27,10 +29,15 @@ class InsightsRetriever(BaseRetriever):
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            vector_engine = get_vector_engine()
-            results = await asyncio.gather(
-                vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
-                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
-            )
+
+            try:
+                results = await asyncio.gather(
+                    vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
+                    vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
+                )
+            except CollectionNotFoundError as error:
+                raise NoDataError("No data found in the system, please add data first.") from error

            results = [*results[0], *results[1]]
            relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
@@ -2,6 +2,8 @@ from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError


class SummariesRetriever(BaseRetriever):

@@ -14,7 +16,14 @@ class SummariesRetriever(BaseRetriever):
    async def get_context(self, query: str) -> Any:
        """Retrieves summary context based on the query."""
        vector_engine = get_vector_engine()
-        summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit)
+
+        try:
+            summaries_results = await vector_engine.search(
+                "TextSummary_text", query, limit=self.limit
+            )
+        except CollectionNotFoundError as error:
+            raise NoDataError("No data found in the system, please add data first.") from error

        return [summary.payload for summary in summaries_results]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -82,9 +82,6 @@ async def brute_force_triplet_search(
    if user is None:
        user = await get_default_user()

-    if user is None:
-        raise PermissionError("No user found in the system. Please create a user.")
-
    retrieved_results = await brute_force_search(
        query,
        user,

@@ -174,4 +171,4 @@ async def brute_force_search(
        send_telemetry(
            "cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
        )
-        raise RuntimeError("An error occurred during brute force search") from error
+        raise error
@@ -20,9 +20,6 @@ async def code_description_to_code_part_search(
    if user is None:
        user = await get_default_user()

-    if user is None:
-        raise PermissionError("No user found in the system. Please create a user.")
-
    retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
    return retrieved_codeparts
@@ -2,9 +2,11 @@ from types import SimpleNamespace
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from cognee.modules.users.models import User
+from cognee.base_config import get_base_config
+from cognee.modules.users.exceptions.exceptions import UserNotFoundError
+from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.create_default_user import create_default_user
-from cognee.base_config import get_base_config


async def get_default_user() -> SimpleNamespace:

@@ -12,16 +14,24 @@ async def get_default_user() -> SimpleNamespace:
    base_config = get_base_config()
    default_email = base_config.default_user_email or "default_user@example.com"

-    async with db_engine.get_async_session() as session:
-        query = select(User).options(selectinload(User.roles)).where(User.email == default_email)
+    try:
+        async with db_engine.get_async_session() as session:
+            query = (
+                select(User).options(selectinload(User.roles)).where(User.email == default_email)
+            )

-        result = await session.execute(query)
-        user = result.scalars().first()
+            result = await session.execute(query)
+            user = result.scalars().first()

-        if user is None:
-            return await create_default_user()
+            if user is None:
+                return await create_default_user()

-        # We return a SimpleNamespace to have the same user type as our SaaS
-        # SimpleNamespace is just a dictionary which can be accessed through attributes
-        auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
-        return auth_data
+            # We return a SimpleNamespace to have the same user type as our SaaS
+            # SimpleNamespace is just a dictionary which can be accessed through attributes
+            auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
+            return auth_data
+    except Exception as error:
+        if "principals" in str(error.args):
+            raise DatabaseNotCreatedError() from error
+
+        raise UserNotFoundError(f"Failed to retrieve default user: {default_email}") from error
@@ -1 +1 @@
-from cognee.tasks.completion.exceptions import NoRelevantDataFound
+from cognee.tasks.completion.exceptions import NoRelevantDataError
@@ -5,5 +5,5 @@ This module defines a set of exceptions for handling various compute errors
"""

from .exceptions import (
-    NoRelevantDataFound,
+    NoRelevantDataError,
)
@@ -2,11 +2,11 @@ from cognee.exceptions import CogneeApiError
from fastapi import status


-class NoRelevantDataFound(CogneeApiError):
+class NoRelevantDataError(CogneeApiError):
    def __init__(
        self,
        message: str = "Search did not find any data.",
-        name: str = "NoRelevantDataFound",
+        name: str = "NoRelevantDataError",
        status_code=status.HTTP_404_NOT_FOUND,
    ):
        super().__init__(message, name, status_code)
@@ -1,8 +1,5 @@
-import asyncio
-import time
import os
-from unittest.mock import patch, MagicMock
-from functools import lru_cache
+import asyncio
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import (
    sleep_and_retry_sync,
@@ -1,120 +1,195 @@
-import uuid
-from unittest.mock import AsyncMock, MagicMock, patch
-
+import os
import pytest
+import pathlib
+
+import cognee
+from cognee.low_level import setup
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever


class TestChunksRetriever:
-    @pytest.fixture
-    def mock_retriever(self):
-        return ChunksRetriever()
    @pytest.mark.asyncio
+    async def test_chunk_context_simple(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)
+
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        await setup()
+
+        document = TextDocument(
+            name="Steve Rodger's career",
+            raw_data_location="somewhere",
+            external_metadata="",
+            mime_type="text/plain",
+        )
+
+        chunk1 = DocumentChunk(
+            text="Steve Rodger",
+            chunk_size=2,
+            chunk_index=0,
+            cut_type="sentence_end",
+            is_part_of=document,
+            contains=[],
+        )
+        chunk2 = DocumentChunk(
+            text="Mike Broski",
+            chunk_size=2,
+            chunk_index=1,
+            cut_type="sentence_end",
+            is_part_of=document,
+            contains=[],
+        )
+        chunk3 = DocumentChunk(
+            text="Christina Mayer",
+            chunk_size=2,
+            chunk_index=2,
+            cut_type="sentence_end",
+            is_part_of=document,
+            contains=[],
+        )
+
+        entities = [chunk1, chunk2, chunk3]
+
+        await add_data_points(entities)
+
+        retriever = ChunksRetriever()
+
+        context = await retriever.get_context("Mike")
+
+        assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
-    @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
-    async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
-        # Setup
-        query = "test query"
-        doc_id1 = str(uuid.uuid4())
-        doc_id2 = str(uuid.uuid4())
+    async def test_chunk_context_complex(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)

-        # Mock search results
-        mock_result_1 = MagicMock()
-        mock_result_1.payload = {
-            "id": str(uuid.uuid4()),
-            "text": "This is the first chunk result.",
-            "document_id": doc_id1,
-            "metadata": {"title": "Document 1"},
-        }
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        await setup()

-        mock_result_2 = MagicMock()
-        mock_result_2.payload = {
-            "id": str(uuid.uuid4()),
-            "text": "This is the second chunk result.",
-            "document_id": doc_id2,
-            "metadata": {"title": "Document 2"},
-        }
+        document1 = TextDocument(
+            name="Employee List",
+            raw_data_location="somewhere",
+            external_metadata="",
+            mime_type="text/plain",
+        )

-        mock_search_results = [mock_result_1, mock_result_2]
-        mock_vector_engine = AsyncMock()
-        mock_vector_engine.search.return_value = mock_search_results
-        mock_get_vector_engine.return_value = mock_vector_engine
+        document2 = TextDocument(
+            name="Car List",
+            raw_data_location="somewhere",
+            external_metadata="",
+            mime_type="text/plain",
+        )

-        # Execute
-        results = await mock_retriever.get_completion(query)
+        chunk1 = DocumentChunk(
+            text="Steve Rodger",
+            chunk_size=2,
+            chunk_index=0,
+            cut_type="sentence_end",
+            is_part_of=document1,
+            contains=[],
+        )
+        chunk2 = DocumentChunk(
+            text="Mike Broski",
+            chunk_size=2,
+            chunk_index=1,
+            cut_type="sentence_end",
+            is_part_of=document1,
+            contains=[],
+        )
+        chunk3 = DocumentChunk(
+            text="Christina Mayer",
+            chunk_size=2,
+            chunk_index=2,
+            cut_type="sentence_end",
+            is_part_of=document1,
+            contains=[],
+        )

-        # Verify
-        assert len(results) == 2
+        chunk4 = DocumentChunk(
+            text="Range Rover",
+            chunk_size=2,
+            chunk_index=0,
+            cut_type="sentence_end",
+            is_part_of=document2,
+            contains=[],
+        )
+        chunk5 = DocumentChunk(
+            text="Hyundai",
+            chunk_size=2,
+            chunk_index=1,
+            cut_type="sentence_end",
+            is_part_of=document2,
+            contains=[],
+        )
+        chunk6 = DocumentChunk(
+            text="Chrysler",
+            chunk_size=2,
+            chunk_index=2,
+            cut_type="sentence_end",
+            is_part_of=document2,
+            contains=[],
+        )

-        # Check first result
-        assert results[0]["text"] == "This is the first chunk result."
-        assert results[0]["document_id"] == doc_id1
-        assert results[0]["metadata"]["title"] == "Document 1"
+        entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]

-        # Check second result
-        assert results[1]["text"] == "This is the second chunk result."
-        assert results[1]["document_id"] == doc_id2
-        assert results[1]["metadata"]["title"] == "Document 2"
+        await add_data_points(entities)

-        # Verify search was called correctly
-        mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
+        retriever = ChunksRetriever(top_k=20)
+
+        context = await retriever.get_context("Christina")
+
+        assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
-    @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
-    async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
-        # Setup
-        query = "test query with no results"
-        mock_search_results = []
-        mock_vector_engine = AsyncMock()
-        mock_vector_engine.search.return_value = mock_search_results
-        mock_get_vector_engine.return_value = mock_vector_engine
+    async def test_chunk_context_on_empty_graph(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)

-        # Execute
-        results = await mock_retriever.get_completion(query)
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)

-        # Verify
-        assert len(results) == 0
-        mock_vector_engine.search.assert_called_once_with("DocumentChunk_text", query, limit=5)
+        retriever = ChunksRetriever()

-    @pytest.mark.asyncio
-    @patch("cognee.modules.retrieval.chunks_retriever.get_vector_engine")
-    async def test_get_completion_with_missing_fields(self, mock_get_vector_engine, mock_retriever):
-        # Setup
-        query = "test query with incomplete data"
+        with pytest.raises(NoDataError):
+            await retriever.get_context("Christina Mayer")

-        # Mock search results
-        mock_result_1 = MagicMock()
-        mock_result_1.payload = {
-            "id": str(uuid.uuid4()),
-            "text": "This chunk has no document_id.",
-            # Missing document_id and metadata
-        }
-        mock_result_2 = MagicMock()
-        mock_result_2.payload = {
-            "id": str(uuid.uuid4()),
-            # Missing text
-            "document_id": str(uuid.uuid4()),
-            "metadata": {"title": "Document with missing text"},
-        }
+        vector_engine = get_vector_engine()
+        await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)

-        mock_search_results = [mock_result_1, mock_result_2]
-        mock_vector_engine = AsyncMock()
-        mock_vector_engine.search.return_value = mock_search_results
-        mock_get_vector_engine.return_value = mock_vector_engine
+        context = await retriever.get_context("Christina Mayer")
+        assert len(context) == 0, "Found chunks when none should exist"

-        # Execute
-        results = await mock_retriever.get_completion(query)
-
-        # Verify
-        assert len(results) == 2
+if __name__ == "__main__":
+    from asyncio import run

-        # First result should have content but no document_id
-        assert results[0]["text"] == "This chunk has no document_id."
-        assert "document_id" not in results[0]
-        assert "metadata" not in results[0]
+    test = TestChunksRetriever()

-        # Second result should have document_id and metadata but no content
-        assert "text" not in results[1]
-        assert "document_id" in results[1]
-        assert results[1]["metadata"]["title"] == "Document with missing text"
+    run(test.test_chunk_context_simple())
+    run(test.test_chunk_context_complex())
+    run(test.test_chunk_context_on_empty_graph())
@@ -1,84 +0,0 @@
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from cognee.modules.retrieval.completion_retriever import CompletionRetriever
-
-
-class TestCompletionRetriever:
-    @pytest.fixture
-    def mock_retriever(self):
-        return CompletionRetriever(system_prompt_path="test_prompt.txt")
-
-    @pytest.mark.asyncio
-    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
-    @patch("cognee.modules.retrieval.utils.completion.render_prompt")
-    @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
-    async def test_get_completion(
-        self, mock_get_vector_engine, mock_render_prompt, mock_get_llm_client, mock_retriever
-    ):
-        # Setup
-        query = "test query"
-
-        # Mock render_prompt
-        mock_render_prompt.return_value = "Rendered prompt with context"
-
-        mock_search_results = [MagicMock()]
-        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
-        mock_vector_engine = AsyncMock()
-        mock_vector_engine.search.return_value = mock_search_results
-        mock_get_vector_engine.return_value = mock_vector_engine
-
-        # Mock LLM client
-        mock_llm_client = MagicMock()
-        mock_llm_client.acreate_structured_output = AsyncMock()
-        mock_llm_client.acreate_structured_output.return_value = "Generated completion response"
-        mock_get_llm_client.return_value = mock_llm_client
-
-        # Execute
-        results = await mock_retriever.get_completion(query)
-
-        # Verify
-        assert len(results) == 1
-        assert results[0] == "Generated completion response"
-
-        # Verify prompt was rendered
-        mock_render_prompt.assert_called_once()
-
-        # Verify LLM client was called
-        mock_llm_client.acreate_structured_output.assert_called_once_with(
-            text_input="Rendered prompt with context", system_prompt=None, response_model=str
-        )
-
-    @pytest.mark.asyncio
-    @patch("cognee.modules.retrieval.completion_retriever.generate_completion")
-    @patch("cognee.modules.retrieval.completion_retriever.get_vector_engine")
-    async def test_get_completion_with_custom_prompt(
-        self, mock_get_vector_engine, mock_generate_completion, mock_retriever
-    ):
-        # Setup
-        query = "test query with custom prompt"
-
-        mock_search_results = [MagicMock()]
-        mock_search_results[0].payload = {"text": "This is a sample document chunk."}
-        mock_vector_engine = AsyncMock()
-        mock_vector_engine.search.return_value = mock_search_results
-        mock_get_vector_engine.return_value = mock_vector_engine
-
-        mock_retriever.user_prompt_path = "custom_user_prompt.txt"
-        mock_retriever.system_prompt_path = "custom_system_prompt.txt"
-
-        mock_generate_completion.return_value = "Custom prompt completion response"
-
-        # Execute
-        results = await mock_retriever.get_completion(query)
-
-        # Verify
-        assert len(results) == 1
-        assert results[0] == "Custom prompt completion response"
-
-        assert mock_generate_completion.call_args[1]["user_prompt_path"] == "custom_user_prompt.txt"
-        assert (
-            mock_generate_completion.call_args[1]["system_prompt_path"]
-            == "custom_system_prompt.txt"
-        )
@@ -1,236 +1,159 @@
from unittest.mock import AsyncMock, MagicMock, patch

import os
import pytest
import pathlib
from typing import Optional, Union

import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class TestGraphCompletionRetriever:
    @pytest.fixture
    def mock_retriever(self):
        return GraphCompletionRetriever(system_prompt_path="test_prompt.txt")

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
    async def test_get_triplets_success(self, mock_brute_force_triplet_search, mock_retriever):
        mock_brute_force_triplet_search.return_value = [
            AsyncMock(
                node1=AsyncMock(attributes={"text": "Node A"}),
                attributes={"relationship_type": "connects"},
                node2=AsyncMock(attributes={"text": "Node B"}),
            )
        ]

        result = await mock_retriever.get_triplets("test query")

        assert isinstance(result, list)
        assert len(result) > 0
        assert result[0].attributes["relationship_type"] == "connects"
        mock_brute_force_triplet_search.assert_called_once()

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search")
    async def test_get_triplets_no_results(self, mock_brute_force_triplet_search, mock_retriever):
        mock_brute_force_triplet_search.return_value = []

        with pytest.raises(NoRelevantDataFound):
            await mock_retriever.get_triplets("test query")

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text(self, mock_retriever):
        node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
        node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
        node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})

        triplets = [
            AsyncMock(
                node1=node_a,
                attributes={"relationship_type": "connects"},
                node2=node_b,
            ),
            AsyncMock(
                node1=node_a,
                attributes={"relationship_type": "links"},
                node2=node_c,
            ),
        ]

        with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
            result = await mock_retriever.resolve_edges_to_text(triplets)

        assert "Nodes:" in result
        assert "Connections:" in result

        assert "Node: Test Title" in result
        assert "__node_content_start__" in result
        assert "Node A text content" in result
        assert "__node_content_end__" in result
        assert "Node: Node C" in result

        assert "Test Title --[connects]--> Test Title" in result
        assert "Test Title --[links]--> Node C" in result

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_triplets",
        new_callable=AsyncMock,
    )
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
        new_callable=AsyncMock,
    )
    async def test_get_context(self, mock_resolve_edges_to_text, mock_get_triplets, mock_retriever):
        """Test get_context calls get_triplets and resolve_edges_to_text."""
        mock_get_triplets.return_value = ["mock_triplet"]
        mock_resolve_edges_to_text.return_value = "Mock Context"

        result = await mock_retriever.get_context("test query")

        assert result == "Mock Context"
        mock_get_triplets.assert_called_once_with("test query")
        mock_resolve_edges_to_text.assert_called_once_with(["mock_triplet"])

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
    )
    @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
    async def test_get_completion_without_context(
        self, mock_generate_completion, mock_get_context, mock_retriever
    ):
        """Test get_completion when no context is provided (calls get_context)."""
        mock_get_context.return_value = "Mock Context"
        mock_generate_completion.return_value = "Generated Completion"

        result = await mock_retriever.get_completion("test query")

        assert result == ["Generated Completion"]
        mock_get_context.assert_called_once_with("test query")
        mock_generate_completion.assert_called_once()

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.get_context"
    )
    @patch("cognee.modules.retrieval.graph_completion_retriever.generate_completion")
    async def test_get_completion_with_context(
        self, mock_generate_completion, mock_get_context, mock_retriever
    ):
        """Test get_completion when context is provided (does not call get_context)."""
        mock_generate_completion.return_value = "Generated Completion"

        result = await mock_retriever.get_completion("test query", context="Provided Context")

        assert result == ["Generated Completion"]
        mock_get_context.assert_not_called()
        mock_generate_completion.assert_called_once()

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
    @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine")
    @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
    async def test_get_completion_with_empty_graph(
        self,
        mock_get_default_user,
        mock_get_graph_engine,
        mock_get_llm_client,
        mock_retriever,
    ):
        query = "test query with empty graph"

        mock_graph_engine = MagicMock()
        mock_graph_engine.get_graph_data = AsyncMock()
        mock_graph_engine.get_graph_data.return_value = ([], [])
        mock_get_graph_engine.return_value = mock_graph_engine

        mock_llm_client = MagicMock()
        mock_llm_client.acreate_structured_output = AsyncMock()
        mock_llm_client.acreate_structured_output.return_value = (
            "Generated graph completion response"
    async def test_graph_completion_context_simple(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
        )
        mock_get_llm_client.return_value = mock_llm_client
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        with pytest.raises(EntityNotFoundError):
            await mock_retriever.get_completion(query)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        mock_graph_engine.get_graph_data.assert_called_once()
        class Company(DataPoint):
            name: str

    def test_top_n_words(self, mock_retriever):
        """Test extraction of top frequent words from text."""
        text = "The quick brown fox jumps over the lazy dog. The fox is quick."
        class Person(DataPoint):
            name: str
            works_for: Company

        result = mock_retriever._top_n_words(text)
        assert len(result.split(", ")) <= 3
        assert "fox" in result
        assert "quick" in result
        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)

        result = mock_retriever._top_n_words(text, top_n=2)
        assert len(result.split(", ")) <= 2
        entities = [company1, company2, person1, person2, person3, person4, person5]

        result = mock_retriever._top_n_words(text, separator=" | ")
        assert " | " in result
        await add_data_points(entities)

        result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"})
        assert "fox" not in result
        assert "quick" not in result
        retriever = GraphCompletionRetriever()

    def test_get_title(self, mock_retriever):
        """Test title generation from text."""
        text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science."
        context = await retriever.get_context("Who works at Canva?")

        title = mock_retriever._get_title(text)
        assert "..." in title
        assert "[" in title and "]" in title
        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        title = mock_retriever._get_title(text, first_n_words=3)
        first_part = title.split("...")[0].strip()
        assert len(first_part.split()) == 3
    @pytest.mark.asyncio
    async def test_graph_completion_context_complex(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        title = mock_retriever._get_title(text, top_n_words=2)
        top_part = title.split("[")[1].split("]")[0]
        assert len(top_part.split(", ")) <= 2
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

    def test_get_nodes(self, mock_retriever):
        """Test node processing and deduplication."""
        node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
        node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
        node_without_attrs = AsyncMock(id="empty_node", attributes={})
        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        edges = [
            AsyncMock(
                node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"}
            ),
            AsyncMock(
                node1=node_with_text,
                node2=node_without_attrs,
                attributes={"relationship_type": "rel2"},
            ),
            AsyncMock(
                node1=node_with_name,
                node2=node_without_attrs,
                attributes={"relationship_type": "rel3",
|
||||
),
|
||||
class Car(DataPoint):
|
||||
brand: str
|
||||
model: str
|
||||
year: int
|
||||
|
||||
class Location(DataPoint):
|
||||
country: str
|
||||
city: str
|
||||
|
||||
class Home(DataPoint):
|
||||
location: Location
|
||||
rooms: int
|
||||
sqm: int
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: Company
|
||||
owns: Optional[list[Union[Car, Home]]] = None
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
company2 = Company(name="Canva")
|
||||
|
||||
person1 = Person(name="Mike Rodger", works_for=company1)
|
||||
person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
|
||||
|
||||
person2 = Person(name="Ike Loma", works_for=company1)
|
||||
person2.owns = [
|
||||
Car(brand="Tesla", model="Model S", year=2021),
|
||||
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
|
||||
]
|
||||
|
||||
with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
|
||||
nodes = mock_retriever._get_nodes(edges)
|
||||
person3 = Person(name="Jason Statham", works_for=company1)
|
||||
|
||||
assert len(nodes) == 3
|
||||
person4 = Person(name="Mike Broski", works_for=company2)
|
||||
person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
|
||||
|
||||
for node_id, info in nodes.items():
|
||||
assert "node" in info
|
||||
assert "name" in info
|
||||
assert "content" in info
|
||||
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||
person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
|
||||
|
||||
text_node_info = nodes[node_with_text.id]
|
||||
assert text_node_info["name"] == "Generated Title"
|
||||
assert text_node_info["content"] == "This is a text node"
|
||||
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||
|
||||
name_node_info = nodes[node_with_name.id]
|
||||
assert name_node_info["name"] == "Named Node"
|
||||
assert name_node_info["content"] == "Named Node"
|
||||
await add_data_points(entities)
|
||||
|
||||
empty_node_info = nodes[node_without_attrs.id]
|
||||
assert empty_node_info["name"] == "Unnamed Node"
|
||||
retriever = GraphCompletionRetriever(top_k=20)
|
||||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
|
||||
print(context)
|
||||
|
||||
assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
|
||||
assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
|
||||
assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
with pytest.raises(DatabaseNotCreatedError):
|
||||
await retriever.get_context("Who works at Figma?")
|
||||
|
||||
await setup()
|
||||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
run(test.test_graph_completion_context_simple())
|
||||
run(test.test_graph_completion_context_complex())
|
||||
run(test.test_get_graph_completion_context_on_empty_graph())
|
||||
|
|
|
|||
|
|
@@ -1,80 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from cognee.modules.retrieval.graph_summary_completion_retriever import (
    GraphSummaryCompletionRetriever,
)


class TestGraphSummaryCompletionRetriever:
    @pytest.fixture
    def mock_retriever(self):
        return GraphSummaryCompletionRetriever(system_prompt_path="test_prompt.txt")

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.utils.completion.get_llm_client")
    @patch("cognee.modules.retrieval.utils.completion.read_query_prompt")
    @patch("cognee.modules.retrieval.utils.completion.render_prompt")
    @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_default_user")
    async def test_get_completion_with_custom_system_prompt(
        self,
        mock_get_default_user,
        mock_render_prompt,
        mock_read_query_prompt,
        mock_get_llm_client,
        mock_retriever,
    ):
        # Setup
        query = "test query with custom prompt"

        # Set custom system prompt
        mock_retriever.user_prompt_path = "custom_user_prompt.txt"
        mock_retriever.system_prompt_path = "custom_system_prompt.txt"

        mock_llm_client = MagicMock()
        mock_llm_client.acreate_structured_output = AsyncMock()
        mock_llm_client.acreate_structured_output.return_value = (
            "Generated graph summary completion response"
        )
        mock_get_llm_client.return_value = mock_llm_client

        # Execute
        results = await mock_retriever.get_completion(query, context="test context")

        # Verify
        assert len(results) == 1

        # Verify render_prompt was called with custom prompt path
        mock_render_prompt.assert_called_once()
        assert mock_render_prompt.call_args[0][0] == "custom_user_prompt.txt"

        mock_read_query_prompt.assert_called_once()
        assert mock_read_query_prompt.call_args[0][0] == "custom_system_prompt.txt"

        mock_llm_client.acreate_structured_output.assert_called_once()

    @pytest.mark.asyncio
    @patch(
        "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text"
    )
    @patch(
        "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
        new_callable=AsyncMock,
    )
    async def test_resolve_edges_to_text_calls_super_and_summarizes(
        self, mock_summarize_text, mock_resolve_edges_to_text, mock_retriever
    ):
        """Test resolve_edges_to_text calls the parent method and summarizes the result."""
        mock_resolve_edges_to_text.return_value = "Raw graph edges text"
        mock_summarize_text.return_value = "Summarized graph text"

        result = await mock_retriever.resolve_edges_to_text(["mock_edge"])

        mock_resolve_edges_to_text.assert_called_once_with(["mock_edge"])
        mock_summarize_text.assert_called_once_with(
            "Raw graph edges text", mock_retriever.summarize_prompt_path
        )

        assert result == "Summarized graph text"
@@ -1,103 +1,216 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import os
import pytest
import pathlib
import unittest

import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.modules.engine.models import Entity, EntityType
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_connected_test_graph


class TestInsightsRetriever:
    @pytest.fixture
    def mock_retriever(self):
        return InsightsRetriever()

    @pytest.mark.asyncio
    async def test_insights_context_simple(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_simple"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_insights_context_simple"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        entityTypePerson = EntityType(
            name="Person",
            description="An individual",
        )

        person1 = Entity(
            name="Steve Rodger",
            is_a=entityTypePerson,
            description="An American actor, comedian, and filmmaker",
        )

        person2 = Entity(
            name="Mike Broski",
            is_a=entityTypePerson,
            description="Financial advisor and philanthropist",
        )

        person3 = Entity(
            name="Christina Mayer",
            is_a=entityTypePerson,
            description="Maker of next generation of iconic American music videos",
        )

        entityTypeCompany = EntityType(
            name="Company",
            description="An organization that operates on an annual basis",
        )

        company1 = Entity(
            name="Apple",
            is_a=entityTypeCompany,
            description="An American multinational technology company headquartered in Cupertino, California",
        )

        company2 = Entity(
            name="Google",
            is_a=entityTypeCompany,
            description="An American multinational technology company that specializes in Internet-related services and products",
        )

        company3 = Entity(
            name="Facebook",
            is_a=entityTypeCompany,
            description="An American social media, messaging, and online platform",
        )

        entities = [person1, person2, person3, company1, company2, company3]

        await add_data_points(entities)

        retriever = InsightsRetriever()

        context = await retriever.get_context("Mike")

        assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
    async def test_get_context_with_existing_node(self, mock_get_graph_engine, mock_retriever):
        """Test get_context when node exists in graph."""
        mock_graph = AsyncMock()
        mock_get_graph_engine.return_value = mock_graph

        # Mock graph response
        mock_graph.extract_node.return_value = {"id": "123"}
        mock_graph.get_connections.return_value = [
            ({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})
        ]

        result = await mock_retriever.get_context("123")

        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0][0]["id"] == "123"
        assert result[0][1]["relationship_name"] == "linked_to"
        assert result[0][2]["id"] == "456"
        mock_graph.extract_node.assert_called_once_with("123")
        mock_graph.get_connections.assert_called_once_with("123")

    @pytest.mark.asyncio
    async def test_insights_context_complex(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_complex"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_insights_context_complex"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        entityTypePerson = EntityType(
            name="Person",
            description="An individual",
        )

        person1 = Entity(
            name="Steve Rodger",
            is_a=entityTypePerson,
            description="An American actor, comedian, and filmmaker",
        )

        person2 = Entity(
            name="Mike Broski",
            is_a=entityTypePerson,
            description="Financial advisor and philanthropist",
        )

        person3 = Entity(
            name="Christina Mayer",
            is_a=entityTypePerson,
            description="Maker of next generation of iconic American music videos",
        )

        person4 = Entity(
            name="Jason Statham",
            is_a=entityTypePerson,
            description="An American actor",
        )

        person5 = Entity(
            name="Mike Tyson",
            is_a=entityTypePerson,
            description="A former professional boxer from the United States",
        )

        entityTypeCompany = EntityType(
            name="Company",
            description="An organization that operates on an annual basis",
        )

        company1 = Entity(
            name="Apple",
            is_a=entityTypeCompany,
            description="An American multinational technology company headquartered in Cupertino, California",
        )

        company2 = Entity(
            name="Google",
            is_a=entityTypeCompany,
            description="An American multinational technology company that specializes in Internet-related services and products",
        )

        company3 = Entity(
            name="Facebook",
            is_a=entityTypeCompany,
            description="An American social media, messaging, and online platform",
        )

        entities = [person1, person2, person3, company1, company2, company3]

        await add_data_points(entities)

        graph_engine = await get_graph_engine()

        await graph_engine.add_edges(
            [
                (person1.id, company1.id, "works_for"),
                (person2.id, company2.id, "works_for"),
                (person3.id, company3.id, "works_for"),
                (person4.id, company1.id, "works_for"),
                (person5.id, company1.id, "works_for"),
            ]
        )

        retriever = InsightsRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
    async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
        # Setup
        query = "test query with no results"
        mock_search_results = []
        mock_vector_engine = AsyncMock()
        mock_vector_engine.search.return_value = mock_search_results
        mock_get_vector_engine.return_value = mock_vector_engine

        # Execute
        results = await mock_retriever.get_completion(query)

        # Verify
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_insights_context_on_empty_graph(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = InsightsRetriever()

        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("Entity_name", payload_schema=Entity)
        await vector_engine.create_collection("EntityType_name", payload_schema=EntityType)

        context = await retriever.get_context("Christina Mayer")
        assert context == [], "Returned context should be empty on an empty graph"

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.insights_retriever.get_graph_engine")
    @patch("cognee.modules.retrieval.insights_retriever.get_vector_engine")
    async def test_get_context_with_no_exact_node(
        self, mock_get_vector_engine, mock_get_graph_engine, mock_retriever
    ):
        """Test get_context when node does not exist in the graph and vector search is used."""
        mock_graph = AsyncMock()
        mock_get_graph_engine.return_value = mock_graph
        mock_graph.extract_node.return_value = None  # Node does not exist

        mock_vector = AsyncMock()
        mock_get_vector_engine.return_value = mock_vector

        mock_vector.search.side_effect = [
            [AsyncMock(id="vec_1", score=0.4)],  # Entity_name search
            [AsyncMock(id="vec_2", score=0.3)],  # EntityType_name search
        ]

        mock_graph.get_connections.side_effect = lambda node_id: [
            ({"id": node_id}, {"relationship_name": "related_to"}, {"id": "456"})
        ]

        result = await mock_retriever.get_context("non_existing_query")

        assert isinstance(result, list)
        assert len(result) == 2
        assert result[0][0]["id"] == "vec_1"
        assert result[0][1]["relationship_name"] == "related_to"
        assert result[0][2]["id"] == "456"

        assert result[1][0]["id"] == "vec_2"
        assert result[1][1]["relationship_name"] == "related_to"
        assert result[1][2]["id"] == "456"

    @pytest.mark.asyncio
    async def test_get_context_with_none_query(self, mock_retriever):
        """Test get_context with a None query (should return empty list)."""
        result = await mock_retriever.get_context(None)
        assert result == []

    @pytest.mark.asyncio
    async def test_get_completion_with_context(self, mock_retriever):
        """Test get_completion when context is already provided."""
        test_context = [({"id": "123"}, {"relationship_name": "linked_to"}, {"id": "456"})]
        result = await mock_retriever.get_completion("test_query", context=test_context)
        assert result == test_context


if __name__ == "__main__":
    from asyncio import run

    test = TestInsightsRetriever()

    run(test.test_insights_context_simple())
    run(test.test_insights_context_complex())
    run(test.test_insights_context_on_empty_graph())
@@ -0,0 +1,196 @@
import os
import pytest
import pathlib

import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.completion_retriever import CompletionRetriever


class TestRAGCompletionRetriever:
    @pytest.mark.asyncio
    async def test_rag_completion_context_simple(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document = TextDocument(
            name="Steve Rodger's career",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3]

        await add_data_points(entities)

        retriever = CompletionRetriever()

        context = await retriever.get_context("Mike")

        assert context == "Mike Broski", "Failed to get Mike Broski"

    @pytest.mark.asyncio
    async def test_rag_completion_context_complex(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )

        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )

        entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]

        await add_data_points(entities)

        # TODO: top_k doesn't affect the output, it should be fixed.
        retriever = CompletionRetriever(top_k=20)

        context = await retriever.get_context("Christina")

        assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_get_rag_completion_context_on_empty_graph(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = CompletionRetriever()

        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("DocumentChunk_text", payload_schema=DocumentChunk)

        context = await retriever.get_context("Christina Mayer")
        assert context == "", "Returned context should be empty on an empty graph"


if __name__ == "__main__":
    from asyncio import run

    test = TestRAGCompletionRetriever()

    run(test.test_rag_completion_context_simple())
    run(test.test_rag_completion_context_complex())
    run(test.test_get_rag_completion_context_on_empty_graph())
@@ -1,122 +1,168 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import os
import pytest
import pathlib

import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking.models import DocumentChunk
from cognee.tasks.summarization.models import TextSummary
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever


class TestSummariesRetriever:
    @pytest.fixture
    def mock_retriever(self):
        return SummariesRetriever()

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
    async def test_get_completion(self, mock_get_vector_engine, mock_retriever):
        # Setup
        query = "test query"
        doc_id1 = str(uuid.uuid4())
        doc_id2 = str(uuid.uuid4())

        # Mock search results
        mock_result_1 = MagicMock()
        mock_result_1.payload = {
            "id": str(uuid.uuid4()),
            "score": 0.95,
            "payload": {
                "text": "This is the first summary.",
                "document_id": doc_id1,
                "metadata": {"title": "Document 1"},
            },
        }
        mock_result_2 = MagicMock()
        mock_result_2.payload = {
            "id": str(uuid.uuid4()),
            "score": 0.85,
            "payload": {
                "text": "This is the second summary.",
                "document_id": doc_id2,
                "metadata": {"title": "Document 2"},
            },
        }

        mock_search_results = [mock_result_1, mock_result_2]
        mock_vector_engine = AsyncMock()
        mock_vector_engine.search.return_value = mock_search_results
        mock_get_vector_engine.return_value = mock_vector_engine

        # Execute
        results = await mock_retriever.get_completion(query)

        # Verify
        assert len(results) == 2

        # Check first result
        assert results[0]["payload"]["text"] == "This is the first summary."
        assert results[0]["payload"]["document_id"] == doc_id1
        assert results[0]["payload"]["metadata"]["title"] == "Document 1"
        assert results[0]["score"] == 0.95

        # Check second result
        assert results[1]["payload"]["text"] == "This is the second summary."
        assert results[1]["payload"]["document_id"] == doc_id2
        assert results[1]["payload"]["metadata"]["title"] == "Document 2"
        assert results[1]["score"] == 0.85

        # Verify search was called correctly
        mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
    async def test_get_completion_with_empty_results(self, mock_get_vector_engine, mock_retriever):
        # Setup
        query = "test query with no results"
        mock_search_results = []
        mock_vector_engine = AsyncMock()
        mock_vector_engine.search.return_value = mock_search_results
        mock_get_vector_engine.return_value = mock_vector_engine

        # Execute
        results = await mock_retriever.get_completion(query)

        # Verify
        assert len(results) == 0
        mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=5)

    @pytest.mark.asyncio
    @patch("cognee.modules.retrieval.summaries_retriever.get_vector_engine")
    async def test_get_completion_with_custom_limit(self, mock_get_vector_engine, mock_retriever):
        # Setup
        query = "test query with custom limit"
        doc_id = str(uuid.uuid4())

        # Mock search results
        mock_result = MagicMock()
        mock_result.payload = {
            "id": str(uuid.uuid4()),
            "score": 0.95,
            "payload": {
                "text": "This is a summary.",
                "document_id": doc_id,
                "metadata": {"title": "Document 1"},
            },
        }

        mock_search_results = [mock_result]
        mock_vector_engine = AsyncMock()
        mock_vector_engine.search.return_value = mock_search_results
        mock_get_vector_engine.return_value = mock_vector_engine

        # Set custom limit
        mock_retriever.limit = 10

        # Execute
        results = await mock_retriever.get_completion(query)

        # Verify
        assert len(results) == 1
        assert results[0]["payload"]["text"] == "This is a summary."

        # Verify search was called with custom limit
        mock_vector_engine.search.assert_called_once_with("TextSummary_text", query, limit=10)


class TextSummariesRetriever:
    @pytest.mark.asyncio
    async def test_chunk_context(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        document1 = TextDocument(
            name="Employee List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        document2 = TextDocument(
            name="Car List",
            raw_data_location="somewhere",
            external_metadata="",
            mime_type="text/plain",
        )

        chunk1 = DocumentChunk(
            text="Steve Rodger",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk1_summary = TextSummary(
            text="S.R.",
            made_from=chunk1,
        )
        chunk2 = DocumentChunk(
            text="Mike Broski",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk2_summary = TextSummary(
            text="M.B.",
            made_from=chunk2,
        )
        chunk3 = DocumentChunk(
            text="Christina Mayer",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document1,
            contains=[],
        )
        chunk3_summary = TextSummary(
            text="C.M.",
            made_from=chunk3,
        )
        chunk4 = DocumentChunk(
            text="Range Rover",
            chunk_size=2,
            chunk_index=0,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk4_summary = TextSummary(
            text="R.R.",
            made_from=chunk4,
        )
        chunk5 = DocumentChunk(
            text="Hyundai",
            chunk_size=2,
            chunk_index=1,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk5_summary = TextSummary(
            text="H.Y.",
            made_from=chunk5,
        )
        chunk6 = DocumentChunk(
            text="Chrysler",
            chunk_size=2,
            chunk_index=2,
            cut_type="sentence_end",
            is_part_of=document2,
            contains=[],
        )
        chunk6_summary = TextSummary(
            text="C.H.",
            made_from=chunk6,
        )

        entities = [
            chunk1_summary,
            chunk2_summary,
            chunk3_summary,
            chunk4_summary,
            chunk5_summary,
            chunk6_summary,
        ]

        await add_data_points(entities)

        retriever = SummariesRetriever(limit=20)

        context = await retriever.get_context("Christina")

        assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer"

    @pytest.mark.asyncio
    async def test_chunk_context_on_empty_graph(self):
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
        )
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        retriever = SummariesRetriever()

        with pytest.raises(NoDataError):
            await retriever.get_context("Christina Mayer")

        vector_engine = get_vector_engine()
        await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary)

        context = await retriever.get_context("Christina Mayer")
        assert context == [], "Returned context should be empty on an empty graph"


if __name__ == "__main__":
    from asyncio import run

    test = TextSummariesRetriever()

    run(test.test_chunk_context())
    run(test.test_chunk_context_on_empty_graph())
@@ -1,11 +1,11 @@
import pytest
from unittest.mock import AsyncMock, patch

from cognee.modules.users.models import User
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
    brute_force_search,
    brute_force_triplet_search,
)


@pytest.mark.asyncio
@@ -20,13 +20,11 @@ async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
     mock_vector_engine.get_distance_from_collection_elements.return_value = []
     mock_get_vector_engine.return_value = mock_vector_engine
 
-    with pytest.raises(Exception) as exc_info:
+    with pytest.raises(CollectionDistancesNotFoundError):
         await brute_force_search(
             query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
         )
-
-    assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
 
 
 @pytest.mark.asyncio
 @patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
@@ -40,9 +38,7 @@ async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
     mock_vector_engine.get_distance_from_collection_elements.return_value = []
     mock_get_vector_engine.return_value = mock_vector_engine
 
-    with pytest.raises(Exception) as exc_info:
+    with pytest.raises(CollectionDistancesNotFoundError):
         await brute_force_triplet_search(
             query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
         )
-
-    assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)