fix: min to days
This commit is contained in:
commit
85a2bac062
55 changed files with 1948 additions and 278 deletions
|
|
@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False
|
|||
# Vector: LanceDB
|
||||
# Graph: KuzuDB
|
||||
#
|
||||
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=False
|
||||
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
|
||||
# Disable mode when using not supported graph/vector databases.
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=True
|
||||
|
||||
################################################################################
|
||||
# ☁️ Cloud Sync Settings
|
||||
|
|
|
|||
41
.github/workflows/e2e_tests.yml
vendored
41
.github/workflows/e2e_tests.yml
vendored
|
|
@ -447,3 +447,44 @@ jobs:
|
|||
DB_USERNAME: cognee
|
||||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_conversation_history.py
|
||||
|
||||
test-load:
|
||||
name: Test Load
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
extra-dependencies: "aws"
|
||||
|
||||
- name: Set File Descriptor Limit
|
||||
run: sudo prlimit --pid $$ --nofile=4096:4096
|
||||
|
||||
- name: Verify File Descriptor Limit
|
||||
run: ulimit -n
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run Load Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: True
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
STORAGE_BACKEND: s3
|
||||
AWS_REGION: eu-west-1
|
||||
AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }}
|
||||
run: uv run python ./cognee/tests/test_load.py
|
||||
25
.github/workflows/examples_tests.yml
vendored
25
.github/workflows/examples_tests.yml
vendored
|
|
@ -210,6 +210,31 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/memify_coding_agent_example.py
|
||||
|
||||
test-custom-pipeline:
|
||||
name: Run Custom Pipeline Example
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run Custom Pipeline Example
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/run_custom_pipeline_example.py
|
||||
|
||||
test-permissions-example:
|
||||
name: Run Permissions Example
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
|
|||
3
.github/workflows/search_db_tests.yml
vendored
3
.github/workflows/search_db_tests.yml
vendored
|
|
@ -84,6 +84,7 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'lancedb'
|
||||
DB_PROVIDER: 'sqlite'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
|
|
@ -135,6 +136,7 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
VECTOR_DB_PROVIDER: 'pgvector'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
|
|
@ -197,6 +199,7 @@ jobs:
|
|||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_NAME: cognee_db
|
||||
DB_HOST: 127.0.0.1
|
||||
DB_PORT: 5432
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ on:
|
|||
required: false
|
||||
type: string
|
||||
default: '["3.10.x", "3.12.x", "3.13.x"]'
|
||||
os:
|
||||
required: false
|
||||
type: string
|
||||
default: '["ubuntu-22.04", "macos-15", "windows-latest"]'
|
||||
secrets:
|
||||
LLM_PROVIDER:
|
||||
required: true
|
||||
|
|
@ -40,10 +44,11 @@ jobs:
|
|||
run-unit-tests:
|
||||
name: Unit tests ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -76,10 +81,11 @@ jobs:
|
|||
run-integration-tests:
|
||||
name: Integration tests ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -112,10 +118,11 @@ jobs:
|
|||
run-library-test:
|
||||
name: Library test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -148,10 +155,11 @@ jobs:
|
|||
run-build-test:
|
||||
name: Build test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -177,10 +185,11 @@ jobs:
|
|||
run-soft-deletion-test:
|
||||
name: Soft Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -214,10 +223,11 @@ jobs:
|
|||
run-hard-deletion-test:
|
||||
name: Hard Delete test ${{ matrix.python-version }} on ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
os: ${{ fromJSON(inputs.os) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
|
|||
4
.github/workflows/test_ollama.yml
vendored
4
.github/workflows/test_ollama.yml
vendored
|
|
@ -75,7 +75,7 @@ jobs:
|
|||
{ "role": "user", "content": "Whatever I say, answer with Yes." }
|
||||
]
|
||||
}'
|
||||
curl -X POST http://127.0.0.1:11434/v1/embeddings \
|
||||
curl -X POST http://127.0.0.1:11434/api/embed \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "avr/sfr-embedding-mistral:latest",
|
||||
|
|
@ -98,7 +98,7 @@ jobs:
|
|||
LLM_MODEL: "phi4"
|
||||
EMBEDDING_PROVIDER: "ollama"
|
||||
EMBEDDING_MODEL: "avr/sfr-embedding-mistral:latest"
|
||||
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embeddings"
|
||||
EMBEDDING_ENDPOINT: "http://localhost:11434/api/embed"
|
||||
EMBEDDING_DIMENSIONS: "4096"
|
||||
HUGGINGFACE_TOKENIZER: "Salesforce/SFR-Embedding-Mistral"
|
||||
run: uv run python ./examples/python/simple_example.py
|
||||
|
|
|
|||
25
.github/workflows/test_suites.yml
vendored
25
.github/workflows/test_suites.yml
vendored
|
|
@ -1,4 +1,6 @@
|
|||
name: Test Suites
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
@ -80,12 +82,22 @@ jobs:
|
|||
uses: ./.github/workflows/notebooks_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
different-operating-systems-tests:
|
||||
name: Operating System and Python Tests
|
||||
different-os-tests-basic:
|
||||
name: OS and Python Tests Ubuntu
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
|
||||
os: '["ubuntu-22.04"]'
|
||||
secrets: inherit
|
||||
|
||||
different-os-tests-extended:
|
||||
name: OS and Python Tests Extended
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.13.x"]'
|
||||
os: '["macos-15", "windows-latest"]'
|
||||
secrets: inherit
|
||||
|
||||
# Matrix-based vector database tests
|
||||
|
|
@ -135,7 +147,8 @@ jobs:
|
|||
e2e-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
llm-tests,
|
||||
|
|
@ -155,7 +168,8 @@ jobs:
|
|||
cli-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
different-os-tests-basic,
|
||||
different-os-tests-extended,
|
||||
vector-db-tests,
|
||||
example-tests,
|
||||
db-examples-tests,
|
||||
|
|
@ -176,7 +190,8 @@ jobs:
|
|||
"${{ needs.cli-tests.result }}" == "success" &&
|
||||
"${{ needs.graph-db-tests.result }}" == "success" &&
|
||||
"${{ needs.notebook-tests.result }}" == "success" &&
|
||||
"${{ needs.different-operating-systems-tests.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-basic.result }}" == "success" &&
|
||||
"${{ needs.different-os-tests-extended.result }}" == "success" &&
|
||||
"${{ needs.vector-db-tests.result }}" == "success" &&
|
||||
"${{ needs.example-tests.result }}" == "success" &&
|
||||
"${{ needs.db-examples-tests.result }}" == "success" &&
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from .api.v1.add import add
|
|||
from .api.v1.delete import delete
|
||||
from .api.v1.cognify import cognify
|
||||
from .modules.memify import memify
|
||||
from .modules.run_custom_pipeline import run_custom_pipeline
|
||||
from .api.v1.update import update
|
||||
from .api.v1.config.config import config
|
||||
from .api.v1.datasets.datasets import datasets
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ from cognee.api.v1.users.routers import (
|
|||
)
|
||||
from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
|
||||
|
||||
# Ensure application logging is configured for container stdout/stderr
|
||||
setup_logging()
|
||||
logger = get_logger()
|
||||
|
||||
if os.getenv("ENV", "prod") == "prod":
|
||||
|
|
@ -74,6 +76,9 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
await get_default_user()
|
||||
|
||||
# Emit a clear startup message for docker logs
|
||||
logger.info("Backend server has started")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from typing import Union
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
|
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
vector_dbs_with_multi_user_support = ["lancedb"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
|
||||
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
|
||||
)
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
"""
|
||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||
|
|
@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if not backend_access_control_enabled():
|
||||
return
|
||||
|
||||
user = await get_user(user_id)
|
||||
|
|
|
|||
|
|
@ -133,6 +133,6 @@ def create_vector_engine(
|
|||
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Unsupported graph database provider: {vector_db_provider}. "
|
||||
f"Unsupported vector database provider: {vector_db_provider}. "
|
||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
|
|
@ -18,9 +18,21 @@ class Edge(BaseModel):
|
|||
|
||||
# Mixed usage
|
||||
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
|
||||
|
||||
# With edge_text for rich embedding representation
|
||||
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
|
||||
"""
|
||||
|
||||
weight: Optional[float] = None
|
||||
weights: Optional[Dict[str, float]] = None
|
||||
relationship_type: Optional[str] = None
|
||||
properties: Optional[Dict[str, Any]] = None
|
||||
edge_text: Optional[str] = None
|
||||
|
||||
@field_validator("edge_text", mode="before")
|
||||
@classmethod
|
||||
def ensure_edge_text(cls, v, info):
|
||||
"""Auto-populate edge_text from relationship_type if not explicitly provided."""
|
||||
if v is None and info.data.get("relationship_type"):
|
||||
return info.data["relationship_type"]
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from cognee import memify
|
||||
from cognee.context_global_variables import (
|
||||
set_database_global_context_variables,
|
||||
set_session_user_context_variable,
|
||||
)
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.tasks.memify import extract_user_sessions, cognify_session
|
||||
|
||||
|
||||
logger = get_logger("persist_sessions_in_knowledge_graph")
|
||||
|
||||
|
||||
async def persist_sessions_in_knowledge_graph_pipeline(
|
||||
user: User,
|
||||
session_ids: Optional[List[str]] = None,
|
||||
dataset: str = "main_dataset",
|
||||
run_in_background: bool = False,
|
||||
):
|
||||
await set_session_user_context_variable(user)
|
||||
dataset_to_write = await get_authorized_existing_datasets(
|
||||
user=user, datasets=[dataset], permission_type="write"
|
||||
)
|
||||
|
||||
if not dataset_to_write:
|
||||
raise CogneeValidationError(
|
||||
message=f"User (id: {str(user.id)}) does not have write access to dataset: {dataset}",
|
||||
log=False,
|
||||
)
|
||||
|
||||
await set_database_global_context_variables(
|
||||
dataset_to_write[0].id, dataset_to_write[0].owner_id
|
||||
)
|
||||
|
||||
extraction_tasks = [Task(extract_user_sessions, session_ids=session_ids)]
|
||||
|
||||
enrichment_tasks = [
|
||||
Task(cognify_session, dataset_id=dataset_to_write[0].id),
|
||||
]
|
||||
|
||||
result = await memify(
|
||||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
dataset=dataset_to_write[0].id,
|
||||
data=[{}],
|
||||
run_in_background=run_in_background,
|
||||
)
|
||||
|
||||
logger.info("Session persistence pipeline completed")
|
||||
return result
|
||||
|
|
@ -3,6 +3,7 @@ from typing import List, Union
|
|||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.tasks.temporal_graph.models import Event
|
||||
|
|
@ -24,7 +25,6 @@ class DocumentChunk(DataPoint):
|
|||
- cut_type: The type of cut that defined this chunk.
|
||||
- is_part_of: The document to which this chunk belongs.
|
||||
- contains: A list of entities or events contained within the chunk (default is None).
|
||||
- last_accessed_at: The timestamp of the last time the chunk was accessed.
|
||||
- metadata: A dictionary to hold meta information related to the chunk, including index
|
||||
fields.
|
||||
"""
|
||||
|
|
@ -34,5 +34,5 @@ class DocumentChunk(DataPoint):
|
|||
chunk_index: int
|
||||
cut_type: str
|
||||
is_part_of: Document
|
||||
contains: List[Union[Entity, Event]] = None
|
||||
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
|
|
|||
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
124
cognee/modules/chunking/text_chunker_with_overlap.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .models.DocumentChunk import DocumentChunk
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextChunkerWithOverlap(Chunker):
|
||||
def __init__(
|
||||
self,
|
||||
document,
|
||||
get_text: callable,
|
||||
max_chunk_size: int,
|
||||
chunk_overlap_ratio: float = 0.0,
|
||||
get_chunk_data: callable = None,
|
||||
):
|
||||
super().__init__(document, get_text, max_chunk_size)
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
self.chunk_overlap_ratio = chunk_overlap_ratio
|
||||
self.chunk_overlap = int(max_chunk_size * chunk_overlap_ratio)
|
||||
|
||||
if get_chunk_data is not None:
|
||||
self.get_chunk_data = get_chunk_data
|
||||
elif chunk_overlap_ratio > 0:
|
||||
paragraph_max_size = int(0.5 * chunk_overlap_ratio * max_chunk_size)
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, paragraph_max_size, batch_paragraphs=True
|
||||
)
|
||||
else:
|
||||
self.get_chunk_data = lambda text: chunk_by_paragraph(
|
||||
text, self.max_chunk_size, batch_paragraphs=True
|
||||
)
|
||||
|
||||
def _accumulation_overflows(self, chunk_data):
|
||||
"""Check if adding chunk_data would exceed max_chunk_size."""
|
||||
return self._accumulated_size + chunk_data["chunk_size"] > self.max_chunk_size
|
||||
|
||||
def _accumulate_chunk_data(self, chunk_data):
|
||||
"""Add chunk_data to the current accumulation."""
|
||||
self._accumulated_chunk_data.append(chunk_data)
|
||||
self._accumulated_size += chunk_data["chunk_size"]
|
||||
|
||||
def _clear_accumulation(self):
|
||||
"""Reset accumulation, keeping overlap chunk_data based on chunk_overlap_ratio."""
|
||||
if self.chunk_overlap == 0:
|
||||
self._accumulated_chunk_data = []
|
||||
self._accumulated_size = 0
|
||||
return
|
||||
|
||||
# Keep chunk_data from the end that fit in overlap
|
||||
overlap_chunk_data = []
|
||||
overlap_size = 0
|
||||
|
||||
for chunk_data in reversed(self._accumulated_chunk_data):
|
||||
if overlap_size + chunk_data["chunk_size"] <= self.chunk_overlap:
|
||||
overlap_chunk_data.insert(0, chunk_data)
|
||||
overlap_size += chunk_data["chunk_size"]
|
||||
else:
|
||||
break
|
||||
|
||||
self._accumulated_chunk_data = overlap_chunk_data
|
||||
self._accumulated_size = overlap_size
|
||||
|
||||
def _create_chunk(self, text, size, cut_type, chunk_id=None):
|
||||
"""Create a DocumentChunk with standard metadata."""
|
||||
try:
|
||||
return DocumentChunk(
|
||||
id=chunk_id or uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text=text,
|
||||
chunk_size=size,
|
||||
is_part_of=self.document,
|
||||
chunk_index=self.chunk_index,
|
||||
cut_type=cut_type,
|
||||
contains=[],
|
||||
metadata={"index_fields": ["text"]},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
def _create_chunk_from_accumulation(self):
|
||||
"""Create a DocumentChunk from current accumulated chunk_data."""
|
||||
chunk_text = " ".join(chunk["text"] for chunk in self._accumulated_chunk_data)
|
||||
return self._create_chunk(
|
||||
text=chunk_text,
|
||||
size=self._accumulated_size,
|
||||
cut_type=self._accumulated_chunk_data[-1]["cut_type"],
|
||||
)
|
||||
|
||||
def _emit_chunk(self, chunk_data):
|
||||
"""Emit a chunk when accumulation overflows."""
|
||||
if len(self._accumulated_chunk_data) > 0:
|
||||
chunk = self._create_chunk_from_accumulation()
|
||||
self._clear_accumulation()
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
else:
|
||||
# Handle single chunk_data exceeding max_chunk_size
|
||||
chunk = self._create_chunk(
|
||||
text=chunk_data["text"],
|
||||
size=chunk_data["chunk_size"],
|
||||
cut_type=chunk_data["cut_type"],
|
||||
chunk_id=chunk_data["chunk_id"],
|
||||
)
|
||||
|
||||
self.chunk_index += 1
|
||||
return chunk
|
||||
|
||||
async def read(self):
|
||||
async for content_text in self.get_text():
|
||||
for chunk_data in self.get_chunk_data(content_text):
|
||||
if not self._accumulation_overflows(chunk_data):
|
||||
self._accumulate_chunk_data(chunk_data)
|
||||
continue
|
||||
|
||||
yield self._emit_chunk(chunk_data)
|
||||
|
||||
if len(self._accumulated_chunk_data) == 0:
|
||||
return
|
||||
|
||||
yield self._create_chunk_from_accumulation()
|
||||
|
|
@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
distance = embedding_map.get(relationship_type, None)
|
||||
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
|
||||
"relationship_type"
|
||||
)
|
||||
distance = embedding_map.get(edge_key, None)
|
||||
if distance is not None:
|
||||
edge.attributes["vector_distance"] = distance
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.engine.models.Edge import Edge
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.engine.utils import (
|
||||
|
|
@ -243,10 +244,26 @@ def _process_graph_nodes(
|
|||
ontology_relationships,
|
||||
)
|
||||
|
||||
# Add entity to data chunk
|
||||
if data_chunk.contains is None:
|
||||
data_chunk.contains = []
|
||||
data_chunk.contains.append(entity_node)
|
||||
|
||||
edge_text = "; ".join(
|
||||
[
|
||||
"relationship_name: contains",
|
||||
f"entity_name: {entity_node.name}",
|
||||
f"entity_description: {entity_node.description}",
|
||||
]
|
||||
)
|
||||
|
||||
data_chunk.contains.append(
|
||||
(
|
||||
Edge(
|
||||
relationship_type="contains",
|
||||
edge_text=edge_text,
|
||||
),
|
||||
entity_node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _process_graph_edges(
|
||||
|
|
|
|||
|
|
@ -1,71 +1,70 @@
|
|||
import string
|
||||
from typing import List
|
||||
from collections import Counter
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
|
||||
def _get_top_n_frequent_words(
|
||||
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
|
||||
) -> str:
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
return separator.join(top_words)
|
||||
|
||||
|
||||
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
"""Creates a title by combining first words with most frequent words from the text."""
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
|
||||
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id in nodes:
|
||||
continue
|
||||
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = _create_title_from_text(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = node.attributes.get("description", name)
|
||||
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
|
||||
"""
|
||||
Converts retrieved graph edges into a human-readable string format.
|
||||
"""Converts retrieved graph edges into a human-readable string format."""
|
||||
nodes = _extract_nodes_from_edges(retrieved_edges)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- retrieved_edges (list): A list of edges retrieved from the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string representation of the nodes and their connections.
|
||||
"""
|
||||
|
||||
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
|
||||
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
import string
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
|
||||
if stop_words:
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
from collections import Counter
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
|
||||
return separator.join(top_words)
|
||||
|
||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _top_n_words(text, top_n=first_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id not in nodes:
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = _get_title(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = node.attributes.get("description", name)
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
return nodes
|
||||
|
||||
nodes = _get_nodes(retrieved_edges)
|
||||
node_section = "\n".join(
|
||||
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
|
||||
for info in nodes.values()
|
||||
)
|
||||
connection_section = "\n".join(
|
||||
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
|
||||
for edge in retrieved_edges
|
||||
)
|
||||
|
||||
connections = []
|
||||
for edge in retrieved_edges:
|
||||
source_name = nodes[edge.node1.id]["name"]
|
||||
target_name = nodes[edge.node2.id]["name"]
|
||||
edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||
connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
|
||||
|
||||
connection_section = "\n".join(connections)
|
||||
|
||||
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ async def get_memory_fragment(
|
|||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
edge_properties_to_project=["relationship_name", "edge_text"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
|
|
|||
1
cognee/modules/run_custom_pipeline/__init__.py
Normal file
1
cognee/modules/run_custom_pipeline/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .run_custom_pipeline import run_custom_pipeline
|
||||
69
cognee/modules/run_custom_pipeline/run_custom_pipeline.py
Normal file
69
cognee/modules/run_custom_pipeline/run_custom_pipeline.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from typing import Union, Optional, List, Type, Any
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.pipelines import run_pipeline
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def run_custom_pipeline(
|
||||
tasks: Union[List[Task], List[str]] = None,
|
||||
data: Any = None,
|
||||
dataset: Union[str, UUID] = "main_dataset",
|
||||
user: User = None,
|
||||
vector_db_config: Optional[dict] = None,
|
||||
graph_db_config: Optional[dict] = None,
|
||||
data_per_batch: int = 20,
|
||||
run_in_background: bool = False,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
):
|
||||
"""
|
||||
Custom pipeline in Cognee, can work with already built graphs. Data needs to be provided which can be processed
|
||||
with provided tasks.
|
||||
|
||||
Provided tasks and data will be arranged to run the Cognee pipeline and execute graph enrichment/creation.
|
||||
|
||||
This is the core processing step in Cognee that converts raw text and documents
|
||||
into an intelligent knowledge graph. It analyzes content, extracts entities and
|
||||
relationships, and creates semantic connections for enhanced search and reasoning.
|
||||
|
||||
Args:
|
||||
tasks: List of Cognee Tasks to execute.
|
||||
data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
|
||||
Data provided here will be forwarded to the first extraction task in the pipeline as input.
|
||||
dataset: Dataset name or dataset uuid to process.
|
||||
user: User context for authentication and data access. Uses default if None.
|
||||
vector_db_config: Custom vector database configuration for embeddings storage.
|
||||
graph_db_config: Custom graph database configuration for relationship storage.
|
||||
data_per_batch: Number of data items to be processed in parallel.
|
||||
run_in_background: If True, starts processing asynchronously and returns immediately.
|
||||
If False, waits for completion before returning.
|
||||
Background mode recommended for large datasets (>100MB).
|
||||
Use pipeline_run_id from return value to monitor progress.
|
||||
"""
|
||||
|
||||
custom_tasks = [
|
||||
*tasks,
|
||||
]
|
||||
|
||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||
|
||||
# Run the run_pipeline in the background or blocking based on executor
|
||||
return await pipeline_executor_func(
|
||||
pipeline=run_pipeline,
|
||||
tasks=custom_tasks,
|
||||
user=user,
|
||||
data=data,
|
||||
datasets=dataset,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=False,
|
||||
data_per_batch=data_per_batch,
|
||||
pipeline_name=pipeline_name,
|
||||
)
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
|
@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -74,7 +74,7 @@ async def search(
|
|||
)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if backend_access_control_enabled():
|
||||
search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -156,7 +156,7 @@ async def search(
|
|||
)
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from ..models import User
|
|||
from ..get_fastapi_users import get_fastapi_users
|
||||
from .get_default_user import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
|
||||
logger = get_logger("get_authenticated_user")
|
||||
|
|
@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
|
|||
# Check environment variable to determine authentication requirement
|
||||
REQUIRE_AUTHENTICATION = (
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
|
||||
or backend_access_control_enabled()
|
||||
)
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine
|
|||
from cognee.modules.users.methods.create_default_user import create_default_user
|
||||
|
||||
|
||||
async def get_default_user() -> SimpleNamespace:
|
||||
async def get_default_user() -> User:
|
||||
db_engine = get_relational_engine()
|
||||
base_config = get_base_config()
|
||||
default_email = base_config.default_user_email or "default_user@example.com"
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
from .extract_subgraph import extract_subgraph
|
||||
from .extract_subgraph_chunks import extract_subgraph_chunks
|
||||
from .cognify_session import cognify_session
|
||||
from .extract_user_sessions import extract_user_sessions
|
||||
|
|
|
|||
41
cognee/tasks/memify/cognify_session.py
Normal file
41
cognee/tasks/memify/cognify_session.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import cognee
|
||||
|
||||
from cognee.exceptions import CogneeValidationError, CogneeSystemError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger("cognify_session")
|
||||
|
||||
|
||||
async def cognify_session(data, dataset_id=None):
|
||||
"""
|
||||
Process and cognify session data into the knowledge graph.
|
||||
|
||||
Adds session content to cognee with a dedicated "user_sessions" node set,
|
||||
then triggers the cognify pipeline to extract entities and relationships
|
||||
from the session data.
|
||||
|
||||
Args:
|
||||
data: Session string containing Question, Context, and Answer information.
|
||||
dataset_name: Name of dataset.
|
||||
|
||||
Raises:
|
||||
CogneeValidationError: If data is None or empty.
|
||||
CogneeSystemError: If cognee operations fail.
|
||||
"""
|
||||
try:
|
||||
if not data or (isinstance(data, str) and not data.strip()):
|
||||
logger.warning("Empty session data provided to cognify_session task, skipping")
|
||||
raise CogneeValidationError(message="Session data cannot be empty", log=False)
|
||||
|
||||
logger.info("Processing session data for cognification")
|
||||
|
||||
await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"])
|
||||
logger.debug("Session data added to cognee with node_set: user_sessions")
|
||||
await cognee.cognify(datasets=[dataset_id])
|
||||
logger.info("Session data successfully cognified")
|
||||
|
||||
except CogneeValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error cognifying session data: {str(e)}")
|
||||
raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False)
|
||||
73
cognee/tasks/memify/extract_user_sessions.py
Normal file
73
cognee/tasks/memify/extract_user_sessions.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from cognee.context_global_variables import session_user
|
||||
from cognee.exceptions import CogneeSystemError
|
||||
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
logger = get_logger("extract_user_sessions")
|
||||
|
||||
|
||||
async def extract_user_sessions(
|
||||
data,
|
||||
session_ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Extract Q&A sessions for the current user from cache.
|
||||
|
||||
Retrieves all Q&A triplets from specified session IDs and yields them
|
||||
as formatted strings combining question, context, and answer.
|
||||
|
||||
Args:
|
||||
data: Data passed from memify. If empty dict ({}), no external data is provided.
|
||||
session_ids: Optional list of specific session IDs to extract.
|
||||
|
||||
Yields:
|
||||
String containing session ID and all Q&A pairs formatted.
|
||||
|
||||
Raises:
|
||||
CogneeSystemError: If cache engine is unavailable or extraction fails.
|
||||
"""
|
||||
try:
|
||||
if not data or data == [{}]:
|
||||
logger.info("Fetching session metadata for current user")
|
||||
|
||||
user: User = session_user.get()
|
||||
if not user:
|
||||
raise CogneeSystemError(message="No authenticated user found in context", log=False)
|
||||
|
||||
user_id = str(user.id)
|
||||
|
||||
cache_engine = get_cache_engine()
|
||||
if cache_engine is None:
|
||||
raise CogneeSystemError(
|
||||
message="Cache engine not available for session extraction, please enable caching in order to have sessions to save",
|
||||
log=False,
|
||||
)
|
||||
|
||||
if session_ids:
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
qa_data = await cache_engine.get_all_qas(user_id, session_id)
|
||||
if qa_data:
|
||||
logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs")
|
||||
session_string = f"Session ID: {session_id}\n\n"
|
||||
for qa_pair in qa_data:
|
||||
question = qa_pair.get("question", "")
|
||||
answer = qa_pair.get("answer", "")
|
||||
session_string += f"Question: {question}\n\nAnswer: {answer}\n\n"
|
||||
yield session_string
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract session {session_id}: {str(e)}")
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
"No specific session_ids provided. Please specify which sessions to extract."
|
||||
)
|
||||
|
||||
except CogneeSystemError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting user sessions: {str(e)}")
|
||||
raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False)
|
||||
|
|
@ -8,47 +8,58 @@ logger = get_logger("index_data_points")
|
|||
|
||||
|
||||
async def index_data_points(data_points: list[DataPoint]):
|
||||
created_indexes = {}
|
||||
index_points = {}
|
||||
"""Index data points in the vector engine by creating embeddings for specified fields.
|
||||
|
||||
Process:
|
||||
1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
|
||||
2. Creates vector indexes for each (type, field) combination on first encounter
|
||||
3. Batches points per (type, field) and creates async indexing tasks
|
||||
4. Executes all indexing tasks in parallel for efficient embedding generation
|
||||
|
||||
Args:
|
||||
data_points: List of DataPoint objects to index. Each DataPoint's metadata must
|
||||
contain an 'index_fields' list specifying which fields to embed.
|
||||
|
||||
Returns:
|
||||
The original data_points list.
|
||||
"""
|
||||
data_points_by_type = {}
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
for data_point in data_points:
|
||||
data_point_type = type(data_point)
|
||||
type_name = data_point_type.__name__
|
||||
|
||||
for field_name in data_point.metadata["index_fields"]:
|
||||
if getattr(data_point, field_name, None) is None:
|
||||
continue
|
||||
|
||||
index_name = f"{data_point_type.__name__}_{field_name}"
|
||||
if type_name not in data_points_by_type:
|
||||
data_points_by_type[type_name] = {}
|
||||
|
||||
if index_name not in created_indexes:
|
||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||
created_indexes[index_name] = True
|
||||
|
||||
if index_name not in index_points:
|
||||
index_points[index_name] = []
|
||||
if field_name not in data_points_by_type[type_name]:
|
||||
await vector_engine.create_vector_index(type_name, field_name)
|
||||
data_points_by_type[type_name][field_name] = []
|
||||
|
||||
indexed_data_point = data_point.model_copy()
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
data_points_by_type[type_name][field_name].append(indexed_data_point)
|
||||
|
||||
tasks: list[asyncio.Task] = []
|
||||
batch_size = vector_engine.embedding_engine.get_batch_size()
|
||||
|
||||
for index_name_and_field, points in index_points.items():
|
||||
first = index_name_and_field.index("_")
|
||||
index_name = index_name_and_field[:first]
|
||||
field_name = index_name_and_field[first + 1 :]
|
||||
batches = (
|
||||
(type_name, field_name, points[i : i + batch_size])
|
||||
for type_name, fields in data_points_by_type.items()
|
||||
for field_name, points in fields.items()
|
||||
for i in range(0, len(points), batch_size)
|
||||
)
|
||||
|
||||
# Create embedding requests per batch to run in parallel later
|
||||
for i in range(0, len(points), batch_size):
|
||||
batch = points[i : i + batch_size]
|
||||
tasks.append(
|
||||
asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
|
||||
)
|
||||
tasks = [
|
||||
asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
|
||||
for type_name, field_name, batch_points in batches
|
||||
]
|
||||
|
||||
# Run all embedding requests in parallel
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return data_points
|
||||
|
|
|
|||
|
|
@ -1,17 +1,44 @@
|
|||
import asyncio
|
||||
from collections import Counter
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from collections import Counter
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.models.EdgeType import EdgeType
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
|
||||
from cognee.tasks.storage.index_data_points import index_data_points
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _get_edge_text(item: dict) -> str:
|
||||
"""Extract edge text for embedding - prefers edge_text field with fallback."""
|
||||
if "edge_text" in item:
|
||||
return item["edge_text"]
|
||||
|
||||
if "relationship_name" in item:
|
||||
return item["relationship_name"]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
|
||||
"""Transform raw edge data into EdgeType datapoints."""
|
||||
edge_texts = [
|
||||
_get_edge_text(item)
|
||||
for edge in edges_data
|
||||
for item in edge
|
||||
if isinstance(item, dict) and "relationship_name" in item
|
||||
]
|
||||
|
||||
edge_types = Counter(edge_texts)
|
||||
|
||||
return [
|
||||
EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
|
||||
for text, count in edge_types.items()
|
||||
]
|
||||
|
||||
|
||||
async def index_graph_edges(
|
||||
edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
|
||||
):
|
||||
|
|
@ -23,24 +50,17 @@ async def index_graph_edges(
|
|||
the `relationship_name` field.
|
||||
|
||||
Steps:
|
||||
1. Initialize the vector engine and graph engine.
|
||||
2. Retrieve graph edge data and count relationship types (`relationship_name`).
|
||||
3. Create vector indexes for `relationship_name` if they don't exist.
|
||||
4. Transform the counted relationships into `EdgeType` objects.
|
||||
5. Index the transformed data points in the vector engine.
|
||||
1. Initialize the graph engine if needed and retrieve edge data.
|
||||
2. Transform edge data into EdgeType datapoints.
|
||||
3. Index the EdgeType datapoints using the standard indexing function.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If initialization of the vector engine or graph engine fails.
|
||||
RuntimeError: If initialization of the graph engine fails.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
created_indexes = {}
|
||||
index_points = {}
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
if edges_data is None:
|
||||
graph_engine = await get_graph_engine()
|
||||
_, edges_data = await graph_engine.get_graph_data()
|
||||
|
|
@ -51,47 +71,7 @@ async def index_graph_edges(
|
|||
logger.error("Failed to initialize engines: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
|
||||
edge_types = Counter(
|
||||
item.get("relationship_name")
|
||||
for edge in edges_data
|
||||
for item in edge
|
||||
if isinstance(item, dict) and "relationship_name" in item
|
||||
)
|
||||
|
||||
for text, count in edge_types.items():
|
||||
edge = EdgeType(
|
||||
id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
|
||||
)
|
||||
data_point_type = type(edge)
|
||||
|
||||
for field_name in edge.metadata["index_fields"]:
|
||||
index_name = f"{data_point_type.__name__}.{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||
created_indexes[index_name] = True
|
||||
|
||||
if index_name not in index_points:
|
||||
index_points[index_name] = []
|
||||
|
||||
indexed_data_point = edge.model_copy()
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
|
||||
# Get maximum batch size for embedding model
|
||||
batch_size = vector_engine.embedding_engine.get_batch_size()
|
||||
tasks: list[asyncio.Task] = []
|
||||
|
||||
for index_name, indexable_points in index_points.items():
|
||||
index_name, field_name = index_name.split(".")
|
||||
|
||||
# Create embedding tasks to run in parallel later
|
||||
for start in range(0, len(indexable_points), batch_size):
|
||||
batch = indexable_points[start : start + batch_size]
|
||||
|
||||
tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
|
||||
|
||||
# Start all embedding tasks and wait for completion
|
||||
await asyncio.gather(*tasks)
|
||||
edge_type_datapoints = create_edge_type_datapoints(edges_data)
|
||||
await index_data_points(edge_type_datapoints)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -39,12 +39,12 @@ async def main():
|
|||
|
||||
answer = await cognee.search("Do programmers change light bulbs?")
|
||||
assert len(answer) != 0
|
||||
lowercase_answer = answer[0].lower()
|
||||
lowercase_answer = answer[0]["search_result"][0].lower()
|
||||
assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
|
||||
|
||||
answer = await cognee.search("What colours are there in the presentation table?")
|
||||
assert len(answer) != 0
|
||||
lowercase_answer = answer[0].lower()
|
||||
lowercase_answer = answer[0]["search_result"][0].lower()
|
||||
assert (
|
||||
("red" in lowercase_answer)
|
||||
and ("blue" in lowercase_answer)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ import cognee
|
|||
import pathlib
|
||||
|
||||
from cognee.infrastructure.databases.cache import get_cache_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from collections import Counter
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -54,10 +56,10 @@ async def main():
|
|||
"""DataCo is a data analytics company. They help businesses make sense of their data."""
|
||||
)
|
||||
|
||||
await cognee.add(text_1, dataset_name)
|
||||
await cognee.add(text_2, dataset_name)
|
||||
await cognee.add(data=text_1, dataset_name=dataset_name)
|
||||
await cognee.add(data=text_2, dataset_name=dataset_name)
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
await cognee.cognify(datasets=[dataset_name])
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@ -188,7 +190,6 @@ async def main():
|
|||
f"GRAPH_SUMMARY_COMPLETION should return non-empty list, got: {result_summary!r}"
|
||||
)
|
||||
|
||||
# Verify saved
|
||||
history_summary = await cache_engine.get_latest_qa(str(user.id), session_id_summary, last_n=10)
|
||||
our_qa_summary = [
|
||||
h for h in history_summary if h["question"] == "What are the key points about TechCorp?"
|
||||
|
|
@ -228,6 +229,46 @@ async def main():
|
|||
assert "CONTEXT:" in formatted_history, "Formatted history should contain 'CONTEXT:' prefix"
|
||||
assert "ANSWER:" in formatted_history, "Formatted history should contain 'ANSWER:' prefix"
|
||||
|
||||
from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
|
||||
persist_sessions_in_knowledge_graph_pipeline,
|
||||
)
|
||||
|
||||
logger.info("Starting persist_sessions_in_knowledge_graph tests")
|
||||
|
||||
await persist_sessions_in_knowledge_graph_pipeline(
|
||||
user=user,
|
||||
session_ids=[session_id_1, session_id_2],
|
||||
dataset=dataset_name,
|
||||
run_in_background=False,
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
||||
type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
|
||||
|
||||
"Tests the correct number of NodeSet nodes after session persistence"
|
||||
assert type_counts.get("NodeSet", 0) == 1, (
|
||||
f"Number of NodeSets in the graph is incorrect, found {type_counts.get('NodeSet', 0)} but there should be exactly 1."
|
||||
)
|
||||
|
||||
"Tests the correct number of DocumentChunk nodes after session persistence"
|
||||
assert type_counts.get("DocumentChunk", 0) == 4, (
|
||||
f"Number of DocumentChunk ndoes in the graph is incorrect, found {type_counts.get('DocumentChunk', 0)} but there should be exactly 4 (2 original documents, 2 sessions)."
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
collection_size = await vector_engine.search(
|
||||
collection_name="DocumentChunk_text",
|
||||
query_text="test",
|
||||
limit=1000,
|
||||
)
|
||||
assert len(collection_size) == 4, (
|
||||
f"DocumentChunk_text collection should have exactly 4 embeddings, found {len(collection_size)}"
|
||||
)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,33 @@ async def test_edge_ingestion():
|
|||
|
||||
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
|
||||
|
||||
"Tests edge_text presence and format"
|
||||
contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
|
||||
assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
|
||||
|
||||
edge_properties = contains_edges[0][3]
|
||||
assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
|
||||
|
||||
edge_text = edge_properties["edge_text"]
|
||||
assert "relationship_name: contains" in edge_text, (
|
||||
f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
|
||||
)
|
||||
assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
|
||||
assert "entity_description:" in edge_text, (
|
||||
f"Expected 'entity_description:' in edge_text, got: {edge_text}"
|
||||
)
|
||||
|
||||
all_edge_texts = [
|
||||
edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
|
||||
]
|
||||
expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
|
||||
found_entity = any(
|
||||
any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
|
||||
)
|
||||
assert found_entity, (
|
||||
f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
|
||||
)
|
||||
|
||||
"Tests the presence of basic nested edges"
|
||||
for basic_nested_edge in basic_nested_edges:
|
||||
assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ async def main():
|
|||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
data=[{}],
|
||||
dataset="feedback_enrichment_test_memify",
|
||||
dataset=dataset_name,
|
||||
)
|
||||
|
||||
nodes_after, edges_after = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -90,15 +90,17 @@ async def main():
|
|||
)
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?"
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What information do you contain?",
|
||||
dataset_ids=[pipeline_run_obj.dataset_id],
|
||||
)
|
||||
assert "Mark" in search_results[0], (
|
||||
assert "Mark" in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, no mention of Mark in search results"
|
||||
)
|
||||
assert "Cindy" in search_results[0], (
|
||||
assert "Cindy" in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, no mention of Cindy in search results"
|
||||
)
|
||||
assert "Artificial intelligence" not in search_results[0], (
|
||||
assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, Artificial intelligence still mentioned in search results"
|
||||
)
|
||||
|
||||
|
|
|
|||
62
cognee/tests/test_load.py
Normal file
62
cognee/tests/test_load.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import pathlib
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import cognee
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def process_and_search(num_of_searches):
|
||||
start_time = time.time()
|
||||
|
||||
await cognee.cognify()
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
cognee.search(
|
||||
query_text="Tell me about the document", query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
for _ in range(num_of_searches)
|
||||
]
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
|
||||
async def main():
|
||||
data_directory_path = os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_load")
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
cognee_directory_path = os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_load")
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
num_of_pdfs = 10
|
||||
num_of_reps = 5
|
||||
upper_boundary_minutes = 10
|
||||
average_minutes = 8
|
||||
|
||||
recorded_times = []
|
||||
for _ in range(num_of_reps):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
s3_input = "s3://cognee-test-load-s3-bucket"
|
||||
await cognee.add(s3_input)
|
||||
|
||||
recorded_times.append(await process_and_search(num_of_pdfs))
|
||||
|
||||
average_recorded_time = sum(recorded_times) / len(recorded_times)
|
||||
|
||||
assert average_recorded_time <= average_minutes * 60
|
||||
|
||||
assert all(rec_time <= upper_boundary_minutes * 60 for rec_time in recorded_times)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str:
|
|||
|
||||
|
||||
async def setup_test_db():
|
||||
# Disable backend access control to migrate relational data
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,13 @@ async def main():
|
|||
assert len(search_results) == 1, (
|
||||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
text = search_results[0]
|
||||
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if backend_access_control_enabled():
|
||||
text = search_results[0]["search_result"][0]
|
||||
else:
|
||||
text = search_results[0]
|
||||
assert isinstance(text, str), f"{name}: element should be a string"
|
||||
assert text.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in text.lower(), (
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
|
@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
|
|||
from types import SimpleNamespace
|
||||
import importlib
|
||||
|
||||
from cognee.api.client import app
|
||||
|
||||
|
||||
# Fixtures for reuse across test classes
|
||||
@pytest.fixture
|
||||
|
|
@ -32,6 +31,10 @@ def mock_authenticated_user():
|
|||
)
|
||||
|
||||
|
||||
# To turn off authentication we need to set the environment variable before importing the module
|
||||
# Also both require_authentication and backend access control must be false
|
||||
os.environ["REQUIRE_AUTHENTICATION"] = "false"
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
||||
|
||||
|
||||
|
|
@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
|
@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
|
|
@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
|
|||
# The exact error message may vary depending on the actual database connection
|
||||
# The important thing is that we get a 500 error when user creation fails
|
||||
|
||||
def test_current_environment_configuration(self):
|
||||
def test_current_environment_configuration(self, client):
|
||||
"""Test that current environment configuration is working properly."""
|
||||
# This tests the actual module state without trying to change it
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from cognee.tasks.storage.index_data_points import index_data_points
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class TestDataPoint(DataPoint):
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_data_points_calls_vector_engine():
|
||||
"""Test that index_data_points creates vector index and indexes data."""
|
||||
data_points = [TestDataPoint(name="test1")]
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
|
||||
|
||||
with patch.dict(
|
||||
index_data_points.__globals__,
|
||||
{"get_vector_engine": lambda: mock_vector_engine},
|
||||
):
|
||||
await index_data_points(data_points)
|
||||
|
||||
assert mock_vector_engine.create_vector_index.await_count >= 1
|
||||
assert mock_vector_engine.index_data_points.await_count >= 1
|
||||
|
|
@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_graph_edges_success():
|
||||
"""Test that index_graph_edges uses the index datapoints and creates vector index."""
|
||||
# Create the mocks for the graph and vector engines.
|
||||
"""Test that index_graph_edges retrieves edges and delegates to index_data_points."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = (
|
||||
None,
|
||||
|
|
@ -15,26 +14,23 @@ async def test_index_graph_edges_success():
|
|||
[{"relationship_name": "rel2"}],
|
||||
],
|
||||
)
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
|
||||
mock_index_data_points = AsyncMock()
|
||||
|
||||
# Patch the globals of the function so that when it does:
|
||||
# vector_engine = get_vector_engine()
|
||||
# graph_engine = await get_graph_engine()
|
||||
# it uses the mocked versions.
|
||||
with patch.dict(
|
||||
index_graph_edges.__globals__,
|
||||
{
|
||||
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
|
||||
"get_vector_engine": lambda: mock_vector_engine,
|
||||
"index_data_points": mock_index_data_points,
|
||||
},
|
||||
):
|
||||
await index_graph_edges()
|
||||
|
||||
# Assertions on the mock calls.
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
assert mock_vector_engine.create_vector_index.await_count == 1
|
||||
assert mock_vector_engine.index_data_points.await_count == 1
|
||||
mock_index_data_points.assert_awaited_once()
|
||||
|
||||
call_args = mock_index_data_points.call_args[0][0]
|
||||
assert len(call_args) == 2
|
||||
assert all(hasattr(item, "relationship_name") for item in call_args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships():
|
|||
"""Test that index_graph_edges handles empty relationships correctly."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = (None, [])
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_index_data_points = AsyncMock()
|
||||
|
||||
with patch.dict(
|
||||
index_graph_edges.__globals__,
|
||||
{
|
||||
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
|
||||
"get_vector_engine": lambda: mock_vector_engine,
|
||||
"index_data_points": mock_index_data_points,
|
||||
},
|
||||
):
|
||||
await index_graph_edges()
|
||||
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
mock_vector_engine.create_vector_index.assert_not_awaited()
|
||||
mock_vector_engine.index_data_points.assert_not_awaited()
|
||||
mock_index_data_points.assert_awaited_once()
|
||||
|
||||
call_args = mock_index_data_points.call_args[0][0]
|
||||
assert len(call_args) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
248
cognee/tests/unit/modules/chunking/test_text_chunker.py
Normal file
248
cognee/tests/unit/modules/chunking/test_text_chunker.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""Unit tests for TextChunker and TextChunkerWithOverlap behavioral equivalence."""
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
|
||||
|
||||
@pytest.fixture(params=["TextChunker", "TextChunkerWithOverlap"])
|
||||
def chunker_class(request):
|
||||
"""Parametrize tests to run against both implementations."""
|
||||
return TextChunker if request.param == "TextChunker" else TextChunkerWithOverlap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_text_generator():
|
||||
"""Factory for async text generators."""
|
||||
|
||||
def _factory(*texts):
|
||||
async def gen():
|
||||
for text in texts:
|
||||
yield text
|
||||
|
||||
return gen
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
async def collect_chunks(chunker):
|
||||
"""Consume async generator and return list of chunks."""
|
||||
chunks = []
|
||||
async for chunk in chunker.read():
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_input_produces_no_chunks(chunker_class, make_text_generator):
|
||||
"""Empty input should yield no chunks."""
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator("")
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 0, "Empty input should produce no chunks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_input_emits_single_chunk(chunker_class, make_text_generator):
|
||||
"""Whitespace-only input should produce exactly one chunk with unchanged text."""
|
||||
whitespace_text = " \n\t \r\n "
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(whitespace_text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 1, "Whitespace-only input should produce exactly one chunk"
|
||||
assert chunks[0].text == whitespace_text, "Chunk text should equal input (whitespace preserved)"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_paragraph_below_limit_emits_one_chunk(chunker_class, make_text_generator):
|
||||
"""Single paragraph below limit should emit exactly one chunk."""
|
||||
text = "This is a short paragraph."
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=512)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 1, "Single short paragraph should produce exactly one chunk"
|
||||
assert chunks[0].text == text, "Chunk text should match input"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[0].chunk_size > 0, "Chunk should have positive size"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_paragraph_gets_emitted_as_a_single_chunk(
|
||||
chunker_class, make_text_generator
|
||||
):
|
||||
"""Oversized paragraph from chunk_by_paragraph should be emitted as single chunk."""
|
||||
text = ("A" * 1500) + ". Next sentence."
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=50)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 2, "Should produce 2 chunks (oversized paragraph + next sentence)"
|
||||
assert chunks[0].chunk_size > 50, "First chunk should be oversized"
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overflow_on_next_paragraph_emits_separate_chunk(chunker_class, make_text_generator):
|
||||
"""First paragraph near limit plus small paragraph should produce two separate chunks."""
|
||||
first_para = " ".join(["word"] * 5)
|
||||
second_para = "Short text."
|
||||
text = first_para + " " + second_para
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=10)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 2, "Should produce 2 chunks due to overflow"
|
||||
assert chunks[0].text.strip() == first_para, "First chunk should contain only first paragraph"
|
||||
assert chunks[1].text.strip() == second_para, (
|
||||
"Second chunk should contain only second paragraph"
|
||||
)
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_paragraphs_batch_correctly(chunker_class, make_text_generator):
|
||||
"""Multiple small paragraphs should batch together with joiner spaces counted."""
|
||||
paragraphs = [" ".join(["word"] * 12) for _ in range(40)]
|
||||
text = " ".join(paragraphs)
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=49)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 20, (
|
||||
"Should batch paragraphs (2 per chunk: 12 words × 2 tokens = 24, 24 + 1 joiner + 24 = 49)"
|
||||
)
|
||||
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
|
||||
"Chunk indices should be sequential"
|
||||
)
|
||||
all_text = " ".join(chunk.text.strip() for chunk in chunks)
|
||||
expected_text = " ".join(paragraphs)
|
||||
assert all_text == expected_text, "All paragraph text should be preserved with correct spacing"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_large_and_small_paragraphs_dont_batch(
|
||||
chunker_class, make_text_generator
|
||||
):
|
||||
"""Alternating near-max and small paragraphs should each become separate chunks."""
|
||||
large1 = "word" * 15 + "."
|
||||
small1 = "Short."
|
||||
large2 = "word" * 15 + "."
|
||||
small2 = "Tiny."
|
||||
text = large1 + " " + small1 + " " + large2 + " " + small2
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
max_chunk_size = 10
|
||||
get_text = make_text_generator(text)
|
||||
chunker = chunker_class(document, get_text, max_chunk_size=max_chunk_size)
|
||||
chunks = await collect_chunks(chunker)
|
||||
|
||||
assert len(chunks) == 4, "Should produce multiple chunks"
|
||||
assert all(c.chunk_index == i for i, c in enumerate(chunks)), (
|
||||
"Chunk indices should be sequential"
|
||||
)
|
||||
assert chunks[0].chunk_size > max_chunk_size, (
|
||||
"First chunk should be oversized (large paragraph)"
|
||||
)
|
||||
assert chunks[1].chunk_size <= max_chunk_size, "Second chunk should be small (small paragraph)"
|
||||
assert chunks[2].chunk_size > max_chunk_size, (
|
||||
"Third chunk should be oversized (large paragraph)"
|
||||
)
|
||||
assert chunks[3].chunk_size <= max_chunk_size, "Fourth chunk should be small (small paragraph)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_indices_and_ids_are_deterministic(chunker_class, make_text_generator):
|
||||
"""Running chunker twice on identical input should produce identical indices and IDs."""
|
||||
sentence1 = "one " * 4 + ". "
|
||||
sentence2 = "two " * 4 + ". "
|
||||
sentence3 = "one " * 4 + ". "
|
||||
sentence4 = "two " * 4 + ". "
|
||||
text = sentence1 + sentence2 + sentence3 + sentence4
|
||||
doc_id = uuid4()
|
||||
max_chunk_size = 20
|
||||
|
||||
document1 = Document(
|
||||
id=doc_id,
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text1 = make_text_generator(text)
|
||||
chunker1 = chunker_class(document1, get_text1, max_chunk_size=max_chunk_size)
|
||||
chunks1 = await collect_chunks(chunker1)
|
||||
|
||||
document2 = Document(
|
||||
id=doc_id,
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text2 = make_text_generator(text)
|
||||
chunker2 = chunker_class(document2, get_text2, max_chunk_size=max_chunk_size)
|
||||
chunks2 = await collect_chunks(chunker2)
|
||||
|
||||
assert len(chunks1) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
|
||||
assert len(chunks2) == 2, "Should produce exactly 2 chunks (4 sentences, 2 per chunk)"
|
||||
assert [c.chunk_index for c in chunks1] == [0, 1], "First run indices should be [0, 1]"
|
||||
assert [c.chunk_index for c in chunks2] == [0, 1], "Second run indices should be [0, 1]"
|
||||
assert chunks1[0].id == chunks2[0].id, "First chunk ID should be deterministic"
|
||||
assert chunks1[1].id == chunks2[1].id, "Second chunk ID should be deterministic"
|
||||
assert chunks1[0].id != chunks1[1].id, "Chunk IDs should be unique within a run"
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
"""Unit tests for TextChunkerWithOverlap overlap behavior."""
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.text_chunker_with_overlap import TextChunkerWithOverlap
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_text_generator():
|
||||
"""Factory for async text generators."""
|
||||
|
||||
def _factory(*texts):
|
||||
async def gen():
|
||||
for text in texts:
|
||||
yield text
|
||||
|
||||
return gen
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_controlled_chunk_data():
|
||||
"""Factory for controlled chunk_data generators."""
|
||||
|
||||
def _factory(*sentences, chunk_size_per_sentence=10):
|
||||
def _chunk_data(text):
|
||||
return [
|
||||
{
|
||||
"text": sentence,
|
||||
"chunk_size": chunk_size_per_sentence,
|
||||
"cut_type": "sentence",
|
||||
"chunk_id": uuid4(),
|
||||
}
|
||||
for sentence in sentences
|
||||
]
|
||||
|
||||
return _chunk_data
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_half_overlap_preserves_content_across_chunks(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 50% overlap, consecutive chunks should share half their content."""
|
||||
s1 = "one"
|
||||
s2 = "two"
|
||||
s3 = "three"
|
||||
s4 = "four"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.5,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 3, "Should produce exactly 3 chunks (s1+s2, s2+s3, s3+s4)"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1, 2], "Chunk indices should be [0, 1, 2]"
|
||||
assert "one" in chunks[0].text and "two" in chunks[0].text, "Chunk 0 should contain s1 and s2"
|
||||
assert "two" in chunks[1].text and "three" in chunks[1].text, (
|
||||
"Chunk 1 should contain s2 (overlap) and s3"
|
||||
)
|
||||
assert "three" in chunks[2].text and "four" in chunks[2].text, (
|
||||
"Chunk 2 should contain s3 (overlap) and s4"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_overlap_produces_no_duplicate_content(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 0% overlap, no content should appear in multiple chunks."""
|
||||
s1 = "one"
|
||||
s2 = "two"
|
||||
s3 = "three"
|
||||
s4 = "four"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.0,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks (s1+s2, s3+s4)"
|
||||
assert "one" in chunks[0].text and "two" in chunks[0].text, (
|
||||
"First chunk should contain s1 and s2"
|
||||
)
|
||||
assert "three" in chunks[1].text and "four" in chunks[1].text, (
|
||||
"Second chunk should contain s3 and s4"
|
||||
)
|
||||
assert "two" not in chunks[1].text and "three" not in chunks[0].text, (
|
||||
"No overlap: end of chunk 0 should not appear in chunk 1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_overlap_ratio_creates_minimal_overlap(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 25% overlap ratio, chunks should have minimal overlap."""
|
||||
s1 = "alpha"
|
||||
s2 = "beta"
|
||||
s3 = "gamma"
|
||||
s4 = "delta"
|
||||
s5 = "epsilon"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=40,
|
||||
chunk_overlap_ratio=0.25,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
|
||||
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
|
||||
"Chunk 0 should contain s1 through s4"
|
||||
)
|
||||
assert s4 in chunks[1].text and s5 in chunks[1].text, (
|
||||
"Chunk 1 should contain overlap s4 and new content s5"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_overlap_ratio_creates_significant_overlap(
|
||||
make_text_generator, make_controlled_chunk_data
|
||||
):
|
||||
"""With 75% overlap ratio, consecutive chunks should share most content."""
|
||||
s1 = "red"
|
||||
s2 = "blue"
|
||||
s3 = "green"
|
||||
s4 = "yellow"
|
||||
s5 = "purple"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, s3, s4, s5, chunk_size_per_sentence=5)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.75,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 2, "Should produce exactly 2 chunks with 75% overlap"
|
||||
assert [c.chunk_index for c in chunks] == [0, 1], "Chunk indices should be [0, 1]"
|
||||
assert all(token in chunks[0].text for token in [s1, s2, s3, s4]), (
|
||||
"Chunk 0 should contain s1, s2, s3, s4"
|
||||
)
|
||||
assert all(token in chunks[1].text for token in [s2, s3, s4, s5]), (
|
||||
"Chunk 1 should contain s2, s3, s4 (overlap) and s5"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_no_dangling_overlap(make_text_generator, make_controlled_chunk_data):
|
||||
"""Text that fits in one chunk should produce exactly one chunk, no overlap artifact."""
|
||||
s1 = "alpha"
|
||||
s2 = "beta"
|
||||
text = "dummy"
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
get_text = make_text_generator(text)
|
||||
get_chunk_data = make_controlled_chunk_data(s1, s2, chunk_size_per_sentence=10)
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=20,
|
||||
chunk_overlap_ratio=0.5,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 1, (
|
||||
"Should produce exactly 1 chunk when content fits within max_chunk_size"
|
||||
)
|
||||
assert chunks[0].chunk_index == 0, "Single chunk should have index 0"
|
||||
assert "alpha" in chunks[0].text and "beta" in chunks[0].text, (
|
||||
"Single chunk should contain all content"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paragraph_chunking_with_overlap(make_text_generator):
|
||||
"""Test that chunk_by_paragraph integration produces 25% overlap between chunks."""
|
||||
|
||||
def mock_get_embedding_engine():
|
||||
class MockEngine:
|
||||
tokenizer = None
|
||||
|
||||
return MockEngine()
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
max_chunk_size = 20
|
||||
overlap_ratio = 0.25 # 5 token overlap
|
||||
paragraph_max_size = int(0.5 * overlap_ratio * max_chunk_size) # = 2
|
||||
|
||||
text = (
|
||||
"A0 A1. A2 A3. A4 A5. A6 A7. A8 A9. " # 10 tokens (0-9)
|
||||
"B0 B1. B2 B3. B4 B5. B6 B7. B8 B9. " # 10 tokens (10-19)
|
||||
"C0 C1. C2 C3. C4 C5. C6 C7. C8 C9. " # 10 tokens (20-29)
|
||||
"D0 D1. D2 D3. D4 D5. D6 D7. D8 D9. " # 10 tokens (30-39)
|
||||
"E0 E1. E2 E3. E4 E5. E6 E7. E8 E9." # 10 tokens (40-49)
|
||||
)
|
||||
|
||||
document = Document(
|
||||
id=uuid4(),
|
||||
name="test_document",
|
||||
raw_data_location="/test/path",
|
||||
external_metadata=None,
|
||||
mime_type="text/plain",
|
||||
)
|
||||
|
||||
get_text = make_text_generator(text)
|
||||
|
||||
def get_chunk_data(text_input):
|
||||
return chunk_by_paragraph(
|
||||
text_input, max_chunk_size=paragraph_max_size, batch_paragraphs=True
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
):
|
||||
chunker = TextChunkerWithOverlap(
|
||||
document,
|
||||
get_text,
|
||||
max_chunk_size=max_chunk_size,
|
||||
chunk_overlap_ratio=overlap_ratio,
|
||||
get_chunk_data=get_chunk_data,
|
||||
)
|
||||
chunks = [chunk async for chunk in chunker.read()]
|
||||
|
||||
assert len(chunks) == 3, f"Should produce exactly 3 chunks, got {len(chunks)}"
|
||||
|
||||
assert chunks[0].chunk_index == 0, "First chunk should have index 0"
|
||||
assert chunks[1].chunk_index == 1, "Second chunk should have index 1"
|
||||
assert chunks[2].chunk_index == 2, "Third chunk should have index 2"
|
||||
|
||||
assert "A0" in chunks[0].text, "Chunk 0 should start with A0"
|
||||
assert "A9" in chunks[0].text, "Chunk 0 should contain A9"
|
||||
assert "B0" in chunks[0].text, "Chunk 0 should contain B0"
|
||||
assert "B9" in chunks[0].text, "Chunk 0 should contain up to B9 (20 tokens)"
|
||||
|
||||
assert "B" in chunks[1].text, "Chunk 1 should have overlap from B section"
|
||||
assert "C" in chunks[1].text, "Chunk 1 should contain C section"
|
||||
assert "D" in chunks[1].text, "Chunk 1 should contain D section"
|
||||
|
||||
assert "D" in chunks[2].text, "Chunk 2 should have overlap from D section"
|
||||
assert "E0" in chunks[2].text, "Chunk 2 should contain E0"
|
||||
assert "E9" in chunks[2].text, "Chunk 2 should end with E9"
|
||||
|
||||
chunk_0_end_words = chunks[0].text.split()[-4:]
|
||||
chunk_1_words = chunks[1].text.split()
|
||||
overlap_0_1 = any(word in chunk_1_words for word in chunk_0_end_words)
|
||||
assert overlap_0_1, (
|
||||
f"No overlap detected between chunks 0 and 1. "
|
||||
f"Chunk 0 ends with: {chunk_0_end_words}, "
|
||||
f"Chunk 1 starts with: {chunk_1_words[:6]}"
|
||||
)
|
||||
|
||||
chunk_1_end_words = chunks[1].text.split()[-4:]
|
||||
chunk_2_words = chunks[2].text.split()
|
||||
overlap_1_2 = any(word in chunk_2_words for word in chunk_1_end_words)
|
||||
assert overlap_1_2, (
|
||||
f"No overlap detected between chunks 1 and 2. "
|
||||
f"Chunk 1 ends with: {chunk_1_end_words}, "
|
||||
f"Chunk 2 starts with: {chunk_2_words[:6]}"
|
||||
)
|
||||
111
cognee/tests/unit/modules/memify_tasks/test_cognify_session.py
Normal file
111
cognee/tests/unit/modules/memify_tasks/test_cognify_session.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from cognee.tasks.memify.cognify_session import cognify_session
|
||||
from cognee.exceptions import CogneeValidationError, CogneeSystemError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_success():
|
||||
"""Test successful cognification of session data."""
|
||||
session_data = (
|
||||
"Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("cognee.add", new_callable=AsyncMock) as mock_add,
|
||||
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
|
||||
):
|
||||
await cognify_session(session_data, dataset_id="123")
|
||||
|
||||
mock_add.assert_called_once_with(
|
||||
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
|
||||
)
|
||||
mock_cognify.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_empty_string():
|
||||
"""Test cognification fails with empty string."""
|
||||
with pytest.raises(CogneeValidationError) as exc_info:
|
||||
await cognify_session("")
|
||||
|
||||
assert "Session data cannot be empty" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_whitespace_string():
|
||||
"""Test cognification fails with whitespace-only string."""
|
||||
with pytest.raises(CogneeValidationError) as exc_info:
|
||||
await cognify_session(" \n\t ")
|
||||
|
||||
assert "Session data cannot be empty" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_none_data():
|
||||
"""Test cognification fails with None data."""
|
||||
with pytest.raises(CogneeValidationError) as exc_info:
|
||||
await cognify_session(None)
|
||||
|
||||
assert "Session data cannot be empty" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_add_failure():
|
||||
"""Test cognification handles cognee.add failure."""
|
||||
session_data = "Session ID: test\n\nQuestion: test?"
|
||||
|
||||
with (
|
||||
patch("cognee.add", new_callable=AsyncMock) as mock_add,
|
||||
patch("cognee.cognify", new_callable=AsyncMock),
|
||||
):
|
||||
mock_add.side_effect = Exception("Add operation failed")
|
||||
|
||||
with pytest.raises(CogneeSystemError) as exc_info:
|
||||
await cognify_session(session_data)
|
||||
|
||||
assert "Failed to cognify session data" in str(exc_info.value)
|
||||
assert "Add operation failed" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_cognify_failure():
|
||||
"""Test cognification handles cognify failure."""
|
||||
session_data = "Session ID: test\n\nQuestion: test?"
|
||||
|
||||
with (
|
||||
patch("cognee.add", new_callable=AsyncMock),
|
||||
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
|
||||
):
|
||||
mock_cognify.side_effect = Exception("Cognify operation failed")
|
||||
|
||||
with pytest.raises(CogneeSystemError) as exc_info:
|
||||
await cognify_session(session_data)
|
||||
|
||||
assert "Failed to cognify session data" in str(exc_info.value)
|
||||
assert "Cognify operation failed" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_re_raises_validation_error():
|
||||
"""Test that CogneeValidationError is re-raised as-is."""
|
||||
with pytest.raises(CogneeValidationError):
|
||||
await cognify_session("")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognify_session_with_special_characters():
|
||||
"""Test cognification with special characters."""
|
||||
session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!"
|
||||
|
||||
with (
|
||||
patch("cognee.add", new_callable=AsyncMock) as mock_add,
|
||||
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
|
||||
):
|
||||
await cognify_session(session_data, dataset_id="123")
|
||||
|
||||
mock_add.assert_called_once_with(
|
||||
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
|
||||
)
|
||||
mock_cognify.assert_called_once()
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import sys
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from cognee.tasks.memify.extract_user_sessions import extract_user_sessions
|
||||
from cognee.exceptions import CogneeSystemError
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
# Get the actual module object (not the function) for patching
|
||||
extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a mock user."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "test-user-123"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qa_data():
|
||||
"""Create mock Q&A data."""
|
||||
return [
|
||||
{
|
||||
"question": "What is cognee?",
|
||||
"context": "context about cognee",
|
||||
"answer": "Cognee is a knowledge graph solution",
|
||||
"time": "2025-01-01T12:00:00",
|
||||
},
|
||||
{
|
||||
"question": "How does it work?",
|
||||
"context": "how it works context",
|
||||
"answer": "It processes data and creates graphs",
|
||||
"time": "2025-01-01T12:05:00",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_success(mock_user, mock_qa_data):
|
||||
"""Test successful extraction of sessions."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
mock_cache_engine.get_all_qas.return_value = mock_qa_data
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions([{}], session_ids=["test_session"]):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 1
|
||||
assert "Session ID: test_session" in sessions[0]
|
||||
assert "Question: What is cognee?" in sessions[0]
|
||||
assert "Answer: Cognee is a knowledge graph solution" in sessions[0]
|
||||
assert "Question: How does it work?" in sessions[0]
|
||||
assert "Answer: It processes data and creates graphs" in sessions[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data):
|
||||
"""Test extraction of multiple sessions."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
mock_cache_engine.get_all_qas.return_value = mock_qa_data
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 2
|
||||
assert mock_cache_engine.get_all_qas.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_no_data(mock_user, mock_qa_data):
|
||||
"""Test extraction handles empty data parameter."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
mock_cache_engine.get_all_qas.return_value = mock_qa_data
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions(None, session_ids=["test_session"]):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_no_session_ids(mock_user):
|
||||
"""Test extraction handles no session IDs provided."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions([{}], session_ids=None):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 0
|
||||
mock_cache_engine.get_all_qas.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_empty_qa_data(mock_user):
|
||||
"""Test extraction handles empty Q&A data."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
mock_cache_engine.get_all_qas.return_value = []
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions([{}], session_ids=["empty_session"]):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data):
|
||||
"""Test extraction continues on cache error for specific session."""
|
||||
mock_cache_engine = AsyncMock()
|
||||
mock_cache_engine.get_all_qas.side_effect = [
|
||||
mock_qa_data,
|
||||
Exception("Cache error"),
|
||||
mock_qa_data,
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
|
||||
patch.object(
|
||||
extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
|
||||
),
|
||||
):
|
||||
mock_session_user.get.return_value = mock_user
|
||||
|
||||
sessions = []
|
||||
async for session in extract_user_sessions(
|
||||
[{}], session_ids=["session1", "session2", "session3"]
|
||||
):
|
||||
sessions.append(session)
|
||||
|
||||
assert len(sessions) == 2
|
||||
|
|
@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
|
|||
# REQUIRE_AUTHENTICATION should be a boolean
|
||||
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||
|
||||
# Currently should be False (optional authentication)
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
|
||||
class TestConditionalAuthenticationEnvironmentVariables:
|
||||
"""Test environment variable handling."""
|
||||
|
||||
def test_require_authentication_default_false(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment - module will see empty environment
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
importlib.invalidate_caches()
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_true(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
|
||||
|
|
@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
|
|||
|
||||
assert REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_false_explicit(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_case_insensitive(self):
|
||||
"""Test that environment variable parsing is case insensitive when imported."""
|
||||
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
|
||||
|
||||
for case in test_cases:
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
expected = case.lower() == "true"
|
||||
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
|
||||
|
||||
def test_current_require_authentication_value(self):
|
||||
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
# The module-level variable should currently be False (set at import time)
|
||||
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
|
||||
class TestConditionalAuthenticationEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ services:
|
|||
- DEBUG=false # Change to true if debugging
|
||||
- HOST=0.0.0.0
|
||||
- ENVIRONMENT=local
|
||||
- LOG_LEVEL=ERROR
|
||||
- LOG_LEVEL=INFO
|
||||
extra_hosts:
|
||||
# Allows the container to reach your local machine using "host.docker.internal" instead of "localhost"
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
|
|
|||
|
|
@ -43,10 +43,10 @@ sleep 2
|
|||
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
||||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
exec debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app
|
||||
else
|
||||
gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload --access-logfile - --error-logfile - cognee.api.client:app
|
||||
fi
|
||||
else
|
||||
gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app
|
||||
exec gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error --access-logfile - --error-logfile - cognee.api.client:app
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ async def run_procurement_example():
|
|||
for q in questions:
|
||||
print(f"Question: \n{q}")
|
||||
results = await procurement_system.search_memory(q, search_categories=[category])
|
||||
top_answer = results[category][0]
|
||||
top_answer = results[category][0]["search_result"][0]
|
||||
print(f"Answer: \n{top_answer.strip()}\n")
|
||||
research_notes[category].append({"question": q, "answer": top_answer})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import cognee
|
||||
from cognee import SearchType
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
|
|
@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
|||
|
||||
|
||||
async def main(repo_path, include_docs):
|
||||
# Disable permissions feature for this example
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
run_status = False
|
||||
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
|
||||
run_status = run_status
|
||||
|
|
|
|||
98
examples/python/conversation_session_persistence_example.py
Normal file
98
examples/python/conversation_session_persistence_example.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import asyncio
|
||||
|
||||
import cognee
|
||||
from cognee import visualize_graph
|
||||
from cognee.memify_pipelines.persist_sessions_in_knowledge_graph import (
|
||||
persist_sessions_in_knowledge_graph_pipeline,
|
||||
)
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger("conversation_session_persistence_example")
|
||||
|
||||
|
||||
async def main():
|
||||
# NOTE: CACHING has to be enabled for this example to work
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
text_1 = "Cognee is a solution that can build knowledge graph from text, creating an AI memory system"
|
||||
text_2 = "Germany is a country located next to the Netherlands"
|
||||
|
||||
await cognee.add([text_1, text_2])
|
||||
await cognee.cognify()
|
||||
|
||||
question = "What can I use to create a knowledge graph?"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=question,
|
||||
)
|
||||
print("\nSession ID: default_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
question = "You sure about that?"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text=question
|
||||
)
|
||||
print("\nSession ID: default_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
question = "This is awesome!"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text=question
|
||||
)
|
||||
print("\nSession ID: default_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
question = "Where is Germany?"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=question,
|
||||
session_id="different_session",
|
||||
)
|
||||
print("\nSession ID: different_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
question = "Next to which country again?"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=question,
|
||||
session_id="different_session",
|
||||
)
|
||||
print("\nSession ID: different_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
question = "So you remember everything I asked from you?"
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=question,
|
||||
session_id="different_session",
|
||||
)
|
||||
print("\nSession ID: different_session")
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {search_results}\n")
|
||||
|
||||
session_ids_to_persist = ["default_session", "different_session"]
|
||||
default_user = await get_default_user()
|
||||
|
||||
await persist_sessions_in_knowledge_graph_pipeline(
|
||||
user=default_user,
|
||||
session_ids=session_ids_to_persist,
|
||||
)
|
||||
|
||||
await visualize_graph()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
|
@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
|||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
data=[{}], # A placeholder to prevent fetching the entire graph
|
||||
dataset="feedback_enrichment_minimal",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ async def main():
|
|||
)
|
||||
|
||||
print("Coding rules created by memify:")
|
||||
for coding_rule in coding_rules:
|
||||
for coding_rule in coding_rules[0]["search_result"][0]:
|
||||
print("- " + coding_rule)
|
||||
|
||||
# Visualize new graph with added memify context
|
||||
|
|
|
|||
|
|
@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
|||
|
||||
|
||||
async def main():
|
||||
# Disable backend access control to migrate relational data
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
# Clean all data stored in Cognee
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
|
|
|||
84
examples/python/run_custom_pipeline_example.py
Normal file
84
examples/python/run_custom_pipeline_example.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import asyncio
|
||||
import cognee
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import setup_logging, INFO
|
||||
from cognee.modules.pipelines import Task
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
||||
# Prerequisites:
|
||||
# 1. Copy `.env.template` and rename it to `.env`.
|
||||
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
|
||||
# LLM_API_KEY = "your_key_here"
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a clean slate for cognee -- reset data and system state
|
||||
print("Resetting cognee data...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
print("Data reset complete.\n")
|
||||
|
||||
# Create relational database and tables
|
||||
await setup()
|
||||
|
||||
# cognee knowledge graph will be created based on this text
|
||||
text = """
|
||||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
|
||||
print("Adding text to cognee:")
|
||||
print(text.strip())
|
||||
|
||||
# Let's recreate the cognee add pipeline through the custom pipeline framework
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
# Values for tasks need to be filled before calling the pipeline
|
||||
add_tasks = [
|
||||
Task(resolve_data_directories, include_subdirectories=True),
|
||||
Task(
|
||||
ingest_data,
|
||||
"main_dataset",
|
||||
user,
|
||||
),
|
||||
]
|
||||
# Forward tasks to custom pipeline along with data and user information
|
||||
await cognee.run_custom_pipeline(
|
||||
tasks=add_tasks, data=text, user=user, dataset="main_dataset", pipeline_name="add_pipeline"
|
||||
)
|
||||
print("Text added successfully.\n")
|
||||
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
from cognee.api.v1.cognify.cognify import get_default_tasks
|
||||
|
||||
cognify_tasks = await get_default_tasks(user=user)
|
||||
print("Recreating existing cognify pipeline in custom pipeline to create knowledge graph...\n")
|
||||
await cognee.run_custom_pipeline(
|
||||
tasks=cognify_tasks, user=user, dataset="main_dataset", pipeline_name="cognify_pipeline"
|
||||
)
|
||||
print("Cognify process complete.\n")
|
||||
|
||||
query_text = "Tell me about NLP"
|
||||
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||
# Query cognee for insights on the added text
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
|
||||
)
|
||||
|
||||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=INFO)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
|
@ -59,14 +59,6 @@ async def main():
|
|||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
||||
# Example output:
|
||||
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
|
||||
# (...)
|
||||
# It represents nodes and relationships in the knowledge graph:
|
||||
# - The first element is the source node (e.g., 'natural language processing').
|
||||
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
|
||||
# - The third element is the target node (e.g., 'computer science').
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=ERROR)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue