Merge branch 'dev' into feature/cog-3165-add-load-tests

Commit 4ff0b3407e by Andrej Milicevic, 2025-10-21 13:39:24 +02:00
52 changed files with 3318 additions and 3070 deletions

View file

@@ -28,11 +28,10 @@ EMBEDDING_ENDPOINT=""
 EMBEDDING_API_VERSION=""
 EMBEDDING_DIMENSIONS=3072
 EMBEDDING_MAX_TOKENS=8191
+EMBEDDING_BATCH_SIZE=36
 # If embedding key is not provided same key set for LLM_API_KEY will be used
 #EMBEDDING_API_KEY="your_api_key"
-# Note: OpenAI support up to 2048 elements and Gemini supports a maximum of 100 elements in an embedding batch,
-# Cognee sets the optimal batch size for OpenAI and Gemini, but a custom size can be defined if necessary for other models
-#EMBEDDING_BATCH_SIZE=2048

 # If using BAML structured output these env variables will be used
 BAML_LLM_PROVIDER=openai

@@ -248,10 +247,10 @@ LITELLM_LOG="ERROR"
 #LLM_PROVIDER="ollama"
 #LLM_ENDPOINT="http://localhost:11434/v1"
 #EMBEDDING_PROVIDER="ollama"
-#EMBEDDING_MODEL="avr/sfr-embedding-mistral:latest"
+#EMBEDDING_MODEL="nomic-embed-text:latest"
 #EMBEDDING_ENDPOINT="http://localhost:11434/api/embeddings"
-#EMBEDDING_DIMENSIONS=4096
-#HUGGINGFACE_TOKENIZER="Salesforce/SFR-Embedding-Mistral"
+#EMBEDDING_DIMENSIONS=768
+#HUGGINGFACE_TOKENIZER="nomic-ai/nomic-embed-text-v1.5"
 ########## OpenRouter (also free) #########################################################

View file

@@ -41,4 +41,4 @@ runs:
          EXTRA_ARGS="$EXTRA_ARGS --extra $extra"
        done
      fi
-     uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j $EXTRA_ARGS
+     uv sync --extra api --extra docs --extra evals --extra codegraph --extra ollama --extra dev --extra neo4j --extra redis $EXTRA_ARGS

View file

@@ -1,4 +1,6 @@
 name: Reusable Integration Tests

+permissions:
+  contents: read

 on:
   workflow_call:

@@ -264,3 +266,68 @@ jobs:
           EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_edge_ingestion.py
  run_concurrent_subprocess_access_test:
    name: Concurrent Subprocess access test
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    services:
      postgres:
        image: pgvector/pgvector:pg17
        env:
          POSTGRES_USER: cognee
          POSTGRES_PASSWORD: cognee
          POSTGRES_DB: cognee_db
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432
      redis:
        image: redis:7
        ports:
          - 6379:6379
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 5s
          --health-timeout 3s
          --health-retries 5
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "postgres redis"

      - name: Run Concurrent subprocess access test (Kuzu/Lancedb/Postgres)
        env:
          ENV: dev
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          GRAPH_DATABASE_PROVIDER: 'kuzu'
          CACHING: true
          SHARED_KUZU_LOCK: true
          DB_PROVIDER: 'postgres'
          DB_NAME: 'cognee_db'
          DB_HOST: '127.0.0.1'
          DB_PORT: 5432
          DB_USERNAME: cognee
          DB_PASSWORD: cognee
        run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py

View file

@@ -110,6 +110,81 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./examples/python/dynamic_steps_example.py
  test-temporal-example:
    name: Run Temporal Tests
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run Temporal Example
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: uv run python ./examples/python/temporal_example.py

  test-ontology-example:
    name: Run Ontology Tests
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run Ontology Demo Example
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: uv run python ./examples/python/ontology_demo_example.py

  test-agentic-reasoning:
    name: Run Agentic Reasoning Tests
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Run Agentic Reasoning Example
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: uv run python ./examples/python/agentic_reasoning_procurement_example.py

  test-memify:
    name: Run Memify Example
    runs-on: ubuntu-22.04

View file

@@ -9,7 +9,7 @@ on:
      python-versions:
        required: false
        type: string
-       default: '["3.10.x", "3.11.x", "3.12.x"]'
+       default: '["3.10.x", "3.12.x", "3.13.x"]'
    secrets:
      LLM_PROVIDER:
        required: true

@@ -193,6 +193,13 @@ jobs:
        with:
          python-version: ${{ matrix.python-version }}

+      - name: Path setup
+        if: matrix.os == 'windows-latest'
+        shell: bash
+        run: |
+          PATH=$(printf '%s' "$PATH" | tr ':' $'\n' | grep -vi '/git/usr/bin' | paste -sd: -)
+          export PATH

      - name: Run Soft Deletion Tests
        env:
          ENV: 'dev'

View file

@@ -85,7 +85,7 @@ jobs:
    needs: [basic-tests, e2e-tests]
    uses: ./.github/workflows/test_different_operating_systems.yml
    with:
-     python-versions: '["3.10.x", "3.11.x", "3.12.x"]'
+     python-versions: '["3.10.x", "3.11.x", "3.12.x", "3.13.x"]'
    secrets: inherit

  # Matrix-based vector database tests

View file

@@ -71,7 +71,7 @@ Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Ext

## Get Started

-Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a>, <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>
+Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/12Vi9zID-M3fpKpKiaqDBvkk98ElkRPWy?usp=sharing">notebook</a>, <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>

## About cognee

@@ -224,12 +224,12 @@ We now have a paper you can cite:

```bibtex
@misc{markovic2025optimizinginterfaceknowledgegraphs,
      title={Optimizing the Interface Between Knowledge Graphs and LLMs for Complex Reasoning},
      author={Vasilije Markovic and Lazar Obradovic and Laszlo Hajdu and Jovan Pavlovic},
      year={2025},
      eprint={2505.24478},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2505.24478},
}
```

View file

@@ -41,6 +41,7 @@ async def add(
    extraction_rules: Optional[Dict[str, Any]] = None,
    tavily_config: Optional[BaseModel] = None,
    soup_crawler_config: Optional[BaseModel] = None,
+   data_per_batch: Optional[int] = 20,
):
    """
    Add data to Cognee for knowledge graph processing.

@@ -235,6 +236,7 @@ async def add(
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
+       data_per_batch=data_per_batch,
    ):
        pipeline_run_info = run_info
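As a usage sketch of the new parameter (the dataset name and texts here are hypothetical, and the call assumes an async context):

```python
import cognee

# Ingest a list of documents, running the per-item pipeline for at most
# 10 items concurrently (data_per_batch was added in this commit).
await cognee.add(
    ["first document text", "second document text"],
    dataset_name="example_dataset",
    data_per_batch=10,
)
```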

View file

@@ -44,6 +44,7 @@ async def cognify(
    graph_model: BaseModel = KnowledgeGraph,
    chunker=TextChunker,
    chunk_size: int = None,
+   chunks_per_batch: int = None,
    config: Config = None,
    vector_db_config: dict = None,
    graph_db_config: dict = None,

@@ -51,6 +52,7 @@ async def cognify(
    incremental_loading: bool = True,
    custom_prompt: Optional[str] = None,
    temporal_cognify: bool = False,
+   data_per_batch: int = 20,
):
    """
    Transform ingested data into a structured knowledge graph.

@@ -105,6 +107,7 @@ async def cognify(
            Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
            Default limits: ~512-8192 tokens depending on models.
            Smaller chunks = more granular but potentially fragmented knowledge.
+       chunks_per_batch: Number of chunks to be processed in a single batch in Cognify tasks.
        vector_db_config: Custom vector database configuration for embeddings storage.
        graph_db_config: Custom graph database configuration for relationship storage.
        run_in_background: If True, starts processing asynchronously and returns immediately.

@@ -209,10 +212,18 @@ async def cognify(
    }

    if temporal_cognify:
-       tasks = await get_temporal_tasks(user, chunker, chunk_size)
+       tasks = await get_temporal_tasks(
+           user=user, chunker=chunker, chunk_size=chunk_size, chunks_per_batch=chunks_per_batch
+       )
    else:
        tasks = await get_default_tasks(
-           user, graph_model, chunker, chunk_size, config, custom_prompt
+           user=user,
+           graph_model=graph_model,
+           chunker=chunker,
+           chunk_size=chunk_size,
+           config=config,
+           custom_prompt=custom_prompt,
+           chunks_per_batch=chunks_per_batch,
        )

    # By calling the get-pipeline-executor helper we get either a function that runs run_pipeline in the background or a function that we will need to await

@@ -228,6 +239,7 @@ async def cognify(
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
        pipeline_name="cognify_pipeline",
+       data_per_batch=data_per_batch,
    )

@@ -238,6 +250,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
    chunk_size: int = None,
    config: Config = None,
    custom_prompt: Optional[str] = None,
+   chunks_per_batch: int = 100,
) -> list[Task]:
    if config is None:
        ontology_config = get_ontology_env_config()

@@ -256,6 +269,9 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
            "ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
        }

+   if chunks_per_batch is None:
+       chunks_per_batch = 100

    default_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),

@@ -269,20 +285,20 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
            graph_model=graph_model,
            config=config,
            custom_prompt=custom_prompt,
-           task_config={"batch_size": 10},
+           task_config={"batch_size": chunks_per_batch},
        ),  # Generate knowledge graphs from the document chunks.
        Task(
            summarize_text,
-           task_config={"batch_size": 10},
+           task_config={"batch_size": chunks_per_batch},
        ),
-       Task(add_data_points, task_config={"batch_size": 10}),
+       Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
    ]

    return default_tasks


async def get_temporal_tasks(
-   user: User = None, chunker=TextChunker, chunk_size: int = None
+   user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = 10
) -> list[Task]:
    """
    Builds and returns a list of temporal processing tasks to be executed in sequence.

@@ -299,10 +315,14 @@ async def get_temporal_tasks(
        user (User, optional): The user requesting task execution, used for permission checks.
        chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
        chunk_size (int, optional): Maximum token size per chunk. If not provided, uses the system default.
+       chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify.

    Returns:
        list[Task]: A list of Task objects representing the temporal processing pipeline.
    """
+   if chunks_per_batch is None:
+       chunks_per_batch = 10

    temporal_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),

@@ -311,9 +331,9 @@ async def get_temporal_tasks(
            max_chunk_size=chunk_size or get_max_chunk_tokens(),
            chunker=chunker,
        ),
-       Task(extract_events_and_timestamps, task_config={"chunk_size": 10}),
+       Task(extract_events_and_timestamps, task_config={"batch_size": chunks_per_batch}),
        Task(extract_knowledge_graph_from_events),
-       Task(add_data_points, task_config={"batch_size": 10}),
+       Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
    ]

    return temporal_tasks
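For illustration, a minimal sketch of how the two new batching knobs compose (the values are arbitrary, and the call assumes an async context):

```python
import cognee

# chunks_per_batch controls batch_size inside the Cognify task pipeline;
# data_per_batch controls how many data items run through the pipeline at once.
await cognee.cognify(
    chunks_per_batch=50,
    data_per_batch=10,
)
```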

View file

@@ -0,0 +1,2 @@
from .get_cache_engine import get_cache_engine
from .config import get_cache_config

View file

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager


class CacheDBInterface(ABC):
    """
    Abstract base class for distributed cache coordination systems (e.g., Redis, Memcached).
    Provides a common interface for lock acquisition, release, and context-managed locking.
    """

    def __init__(self, host: str, port: int, lock_key: str):
        self.host = host
        self.port = port
        self.lock_key = lock_key
        self.lock = None

    @abstractmethod
    def acquire_lock(self):
        """
        Acquire a lock on the given key.
        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def release_lock(self):
        """
        Release the lock if it is held.
        Must be implemented by subclasses.
        """
        pass

    @contextmanager
    def hold_lock(self):
        """
        Context manager for safely acquiring and releasing the lock.
        """
        self.acquire_lock()
        try:
            yield
        finally:
            self.release_lock()
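A minimal sketch of what a concrete adapter has to provide, using a hypothetical in-process lock in place of a real cache server (only to illustrate the contract; real adapters talk to Redis):

```python
import threading

from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface


class LocalLockAdapter(CacheDBInterface):
    # Hypothetical single-process adapter: acquire_lock/release_lock are the
    # abstract methods; hold_lock comes for free from the base class.
    _locks: dict[str, threading.Lock] = {}

    def acquire_lock(self):
        self.lock = self._locks.setdefault(self.lock_key, threading.Lock())
        self.lock.acquire()

    def release_lock(self):
        if self.lock:
            self.lock.release()
            self.lock = None


adapter = LocalLockAdapter(host="localhost", port=0, lock_key="example")
with adapter.hold_lock():
    pass  # critical section runs while the lock is held
```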

View file

@@ -0,0 +1,39 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache


class CacheConfig(BaseSettings):
    """
    Configuration for distributed cache systems (e.g., Redis), used for locking or coordination.

    Attributes:
    - caching: Turns cache-based coordination on/off.
    - shared_kuzu_lock: Turns the shared Kuzu lock logic on/off.
    - cache_host: Hostname of the cache service.
    - cache_port: Port number for the cache service.
    - agentic_lock_expire: Automatic lock expiration time (in seconds).
    - agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
    """

    caching: bool = False
    shared_kuzu_lock: bool = False
    cache_host: str = "localhost"
    cache_port: int = 6379
    agentic_lock_expire: int = 240
    agentic_lock_timeout: int = 300

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        return {
            "caching": self.caching,
            "shared_kuzu_lock": self.shared_kuzu_lock,
            "cache_host": self.cache_host,
            "cache_port": self.cache_port,
            "agentic_lock_expire": self.agentic_lock_expire,
            "agentic_lock_timeout": self.agentic_lock_timeout,
        }


@lru_cache
def get_cache_config():
    return CacheConfig()
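Since this is a pydantic-settings class, field names map to upper-cased environment variables, which is how the CACHING and SHARED_KUZU_LOCK variables in the CI workflow above reach it. A small sketch:

```python
import os

# Set before the first get_cache_config() call, because it is lru_cached.
os.environ["CACHING"] = "true"
os.environ["SHARED_KUZU_LOCK"] = "true"

from cognee.infrastructure.databases.cache.config import get_cache_config

config = get_cache_config()
assert config.caching and config.shared_kuzu_lock
print(config.to_dict())
```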

View file

@@ -0,0 +1,59 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""

from functools import lru_cache

from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface

config = get_cache_config()


@lru_cache
def create_cache_engine(
    cache_host: str,
    cache_port: int,
    lock_key: str,
    agentic_lock_expire: int = 240,
    agentic_lock_timeout: int = 300,
):
    """
    Factory function to instantiate a cache coordination backend (currently Redis).

    Parameters:
    -----------
    - cache_host: Hostname or IP of the cache server.
    - cache_port: Port number to connect to.
    - lock_key: Identifier used for the locking resource.
    - agentic_lock_expire: Duration to hold the lock after acquisition.
    - agentic_lock_timeout: Max time to wait for the lock before failing.

    Returns:
    --------
    - CacheDBInterface: An instance of the appropriate cache adapter.
      TODO: Only Redis is supported for now; if more backends are added, split the selection logic here.
    """
    if config.caching:
        from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter

        return RedisAdapter(
            host=cache_host,
            port=cache_port,
            lock_name=lock_key,
            timeout=agentic_lock_expire,
            blocking_timeout=agentic_lock_timeout,
        )
    else:
        return None


def get_cache_engine(lock_key: str) -> CacheDBInterface:
    """
    Returns a cache adapter instance using the current context configuration.
    """
    return create_cache_engine(
        cache_host=config.cache_host,
        cache_port=config.cache_port,
        lock_key=lock_key,
        agentic_lock_expire=config.agentic_lock_expire,
        agentic_lock_timeout=config.agentic_lock_timeout,
    )
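Because create_cache_engine is memoized with @lru_cache, repeated lookups for the same lock key reuse one adapter instance. A sketch (assuming caching is enabled; with caching off both calls return None):

```python
from cognee.infrastructure.databases.cache import get_cache_engine

lock_a = get_cache_engine(lock_key="kuzu-lock-example")
lock_b = get_cache_engine(lock_key="kuzu-lock-example")
assert lock_a is lock_b  # identical arguments hit the same lru_cache entry
```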

View file

@@ -0,0 +1,49 @@
import redis
from contextlib import contextmanager
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface


class RedisAdapter(CacheDBInterface):
    def __init__(self, host, port, lock_name, timeout=240, blocking_timeout=300):
        super().__init__(host, port, lock_name)
        self.redis = redis.Redis(host=host, port=port)
        self.timeout = timeout
        self.blocking_timeout = blocking_timeout

    def acquire_lock(self):
        """
        Acquire the Redis lock manually. Raises if acquisition fails.
        """
        self.lock = self.redis.lock(
            name=self.lock_key,
            timeout=self.timeout,
            blocking_timeout=self.blocking_timeout,
        )
        acquired = self.lock.acquire()
        if not acquired:
            raise RuntimeError(f"Could not acquire Redis lock: {self.lock_key}")
        return self.lock

    def release_lock(self):
        """
        Release the Redis lock manually, if held.
        """
        if self.lock:
            try:
                self.lock.release()
                self.lock = None
            except redis.exceptions.LockError:
                pass

    @contextmanager
    def hold_lock(self):
        """
        Context manager for acquiring and releasing the Redis lock automatically.
        """
        self.acquire_lock()
        try:
            yield
        finally:
            self.release_lock()
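A usage sketch, assuming a Redis server is reachable on localhost:6379:

```python
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter

adapter = RedisAdapter(host="localhost", port=6379, lock_name="example-lock")

with adapter.hold_lock():
    # Only one process holding "example-lock" gets here at a time; the lock
    # auto-expires after `timeout` seconds if the holder crashes.
    ...
```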

View file

@@ -162,5 +162,5 @@ def create_graph_engine(
    raise EnvironmentError(
        f"Unsupported graph database provider: {graph_database_provider}. "
-       f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
+       f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
    )

View file

@@ -4,7 +4,7 @@ import os
import json
import asyncio
import tempfile
-from uuid import UUID
+from uuid import UUID, uuid5, NAMESPACE_OID
from kuzu import Connection
from kuzu.database import Database
from datetime import datetime, timezone

@@ -23,9 +23,14 @@ from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
+from cognee.infrastructure.databases.cache.config import get_cache_config

logger = get_logger()

+cache_config = get_cache_config()
+if cache_config.shared_kuzu_lock:
+    from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine


class KuzuAdapter(GraphDBInterface):
    """

@@ -39,12 +44,20 @@ class KuzuAdapter(GraphDBInterface):
    def __init__(self, db_path: str):
        """Initialize Kuzu database connection and schema."""
+       self.open_connections = 0
+       self._is_closed = False
        self.db_path = db_path  # Path for the database directory
        self.db: Optional[Database] = None
        self.connection: Optional[Connection] = None
-       self.executor = ThreadPoolExecutor()
-       self._initialize_connection()
+       if cache_config.shared_kuzu_lock:
+           self.redis_lock = get_cache_engine(
+               lock_key="kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path))
+           )
+       else:
+           self.executor = ThreadPoolExecutor()
+           self._initialize_connection()
        self.KUZU_ASYNC_LOCK = asyncio.Lock()
+       self._connection_change_lock = asyncio.Lock()

    def _initialize_connection(self) -> None:
        """Initialize the Kuzu database connection and schema."""

@@ -209,9 +222,13 @@ class KuzuAdapter(GraphDBInterface):
        params = params or {}

        def blocking_query():
+           lock_acquired = False
            try:
+               if cache_config.shared_kuzu_lock:
+                   self.redis_lock.acquire_lock()
+                   lock_acquired = True
                if not self.connection:
-                   logger.debug("Reconnecting to Kuzu database...")
+                   logger.info("Reconnecting to Kuzu database...")
                    self._initialize_connection()

                result = self.connection.execute(query, params)

@@ -225,12 +242,47 @@ class KuzuAdapter(GraphDBInterface):
                            val = val.as_py()
                        processed_rows.append(val)
                    rows.append(tuple(processed_rows))

                return rows
            except Exception as e:
                logger.error(f"Query execution failed: {str(e)}")
                raise
+           finally:
+               if cache_config.shared_kuzu_lock and lock_acquired:
+                   try:
+                       self.close()
+                   finally:
+                       self.redis_lock.release_lock()

-       return await loop.run_in_executor(self.executor, blocking_query)
+       if cache_config.shared_kuzu_lock:
+           async with self._connection_change_lock:
+               self.open_connections += 1
+               logger.info(f"Open connections after open: {self.open_connections}")
+           try:
+               result = blocking_query()
+           finally:
+               self.open_connections -= 1
+               logger.info(f"Open connections after close: {self.open_connections}")
+           return result
+       else:
+           result = await loop.run_in_executor(self.executor, blocking_query)
+           return result

+   def close(self):
+       if self.connection:
+           del self.connection
+           self.connection = None
+       if self.db:
+           del self.db
+           self.db = None
+       self._is_closed = True
+       logger.info("Kuzu database closed successfully")

+   def reopen(self):
+       if self._is_closed:
+           self._is_closed = False
+           self._initialize_connection()
+           logger.info("Kuzu database re-opened successfully")

    @asynccontextmanager
    async def get_session(self):

@@ -1557,44 +1609,6 @@ class KuzuAdapter(GraphDBInterface):
            logger.error(f"Failed to delete graph data: {e}")
            raise

-   async def clear_database(self) -> None:
-       """
-       Clear all data from the database by deleting the database files and reinitializing.
-
-       This method removes all files associated with the database and reinitializes the Kuzu
-       database structure, ensuring a completely empty state. It handles exceptions that might
-       occur during file deletions or initializations carefully.
-       """
-       try:
-           if self.connection:
-               self.connection = None
-           if self.db:
-               self.db.close()
-               self.db = None
-
-           db_dir = os.path.dirname(self.db_path)
-           db_name = os.path.basename(self.db_path)
-
-           file_storage = get_file_storage(db_dir)
-
-           if await file_storage.file_exists(db_name):
-               await file_storage.remove_all()
-               logger.info(f"Deleted Kuzu database files at {self.db_path}")
-
-           # Reinitialize the database
-           self._initialize_connection()
-
-           # Verify the database is empty
-           result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
-           count = result.get_next()[0] if result.has_next() else 0
-           if count > 0:
-               logger.warning(
-                   f"Database still contains {count} nodes after clearing, forcing deletion"
-               )
-               self.connection.execute("MATCH (n:Node) DETACH DELETE n")
-
-           logger.info("Database cleared successfully")
-       except Exception as e:
-           logger.error(f"Error during database clearing: {e}")
-           raise

    async def get_document_subgraph(self, data_id: str):
        """
        Get all nodes that should be deleted when removing a document.
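The lock key derivation above is deterministic, which is what makes the lock shared across processes. A standalone sketch (the path is hypothetical):

```python
from uuid import uuid5, NAMESPACE_OID

# Every process opening the same db_path derives the same key, so they all
# contend on a single Redis lock for that Kuzu database.
db_path = "/data/cognee/kuzu_db"  # hypothetical path
lock_key = "kuzu-lock-" + str(uuid5(NAMESPACE_OID, db_path))
print(lock_key)  # stable across processes and runs
```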

View file

@@ -1067,7 +1067,7 @@ class Neo4jAdapter(GraphDBInterface):
        query_nodes = f"""
        MATCH (n)
        WHERE {where_clause}
-       RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
+       RETURN n.id AS id, labels(n) AS labels, properties(n) AS properties
        """
        result_nodes = await self.query(query_nodes)

@@ -1082,7 +1082,7 @@ class Neo4jAdapter(GraphDBInterface):
        query_edges = f"""
        MATCH (n)-[r]->(m)
        WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
-       RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+       RETURN n.id AS source, m.id AS target, TYPE(r) AS type, properties(r) AS properties
        """
        result_edges = await self.query(query_edges)

View file

@ -1,8 +1,17 @@
from cognee.shared.logging_utils import get_logger import os
import logging
from typing import List, Optional from typing import List, Optional
from fastembed import TextEmbedding from fastembed import TextEmbedding
import litellm import litellm
import os from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions import EmbeddingException from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import ( from cognee.infrastructure.llm.tokenizer.TikToken import (
@ -57,6 +66,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower() enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes") self.mock = enable_mocking in ("true", "1", "yes")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
""" """
Embed the given text into numerical vectors. Embed the given text into numerical vectors.

View file

@ -1,15 +1,21 @@
import asyncio import asyncio
import logging
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
import math import math
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
import litellm import litellm
import os import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions import EmbeddingException from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import (
GeminiTokenizer,
)
from cognee.infrastructure.llm.tokenizer.HuggingFace import ( from cognee.infrastructure.llm.tokenizer.HuggingFace import (
HuggingFaceTokenizer, HuggingFaceTokenizer,
) )
@ -19,10 +25,6 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
from cognee.infrastructure.llm.tokenizer.TikToken import ( from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer, TikTokenTokenizer,
) )
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
litellm.set_verbose = False litellm.set_verbose = False
logger = get_logger("LiteLLMEmbeddingEngine") logger = get_logger("LiteLLMEmbeddingEngine")
@ -76,8 +78,13 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower() enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes") self.mock = enable_mocking in ("true", "1", "yes")
@embedding_sleep_and_retry_async() @retry(
@embedding_rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
""" """
Embed a list of text strings into vector representations. Embed a list of text strings into vector representations.
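The custom rate limiter is replaced with the same tenacity policy across all engines in this commit. A standalone sketch of that policy (tenacity is a real dependency here; flaky_call and ValueError are hypothetical stand-ins for the decorated methods and litellm's NotFoundError):

```python
import logging
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    retry_if_not_exception_type,
    before_sleep_log,
)

logger = logging.getLogger(__name__)


@retry(
    stop=stop_after_delay(128),            # give up after ~128 seconds overall
    wait=wait_exponential_jitter(2, 128),  # exponential backoff with jitter: 2s initial, 128s cap
    retry=retry_if_not_exception_type(ValueError),  # do not retry "not found" style errors
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,                          # re-raise the last error instead of RetryError
)
def flaky_call():
    ...  # placeholder for a network call
```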

View file

@@ -3,8 +3,16 @@ from cognee.shared.logging_utils import get_logger
import aiohttp
from typing import List, Optional
import os
+import litellm
+import logging
import aiohttp.http_exceptions
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.llm.tokenizer.HuggingFace import (

@@ -69,7 +77,6 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

-   @embedding_rate_limit_async
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Generate embedding vectors for a list of text prompts.

@@ -92,7 +99,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
        return embeddings

-   @embedding_sleep_and_retry_async()
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def _get_embedding(self, prompt: str) -> List[float]:
        """
        Internal method to call the Ollama embeddings endpoint for a single prompt.

View file

@@ -24,11 +24,10 @@ class EmbeddingConfig(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def model_post_init(self, __context) -> None:
-       # If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
-           self.embedding_batch_size = 2048
+           self.embedding_batch_size = 36
        elif not self.embedding_batch_size:
-           self.embedding_batch_size = 100
+           self.embedding_batch_size = 36

    def to_dict(self) -> dict:
        """

View file

@@ -124,6 +124,12 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
    """
    file_type = filetype.guess(file)

+   # If the file type could not be determined, treat it as a plain text file, since text files have no magic-number signature
+   if file_type is None:
+       from filetype.types.base import Type
+
+       file_type = Type("text/plain", "txt")

    if file_type is None:
        raise FileTypeException(f"Unknown file detected: {file.name}.")

View file

@@ -1,19 +1,24 @@
+import logging
from typing import Type
from pydantic import BaseModel
+import litellm
import instructor

+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
-from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
-from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config

+logger = get_logger()


class AnthropicAdapter(LLMInterface):
    """

@@ -35,8 +40,13 @@ class AnthropicAdapter(LLMInterface):
        self.model = model
        self.max_completion_tokens = max_completion_tokens

-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

View file

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
+import logging
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

+logger = get_logger()


class GeminiAdapter(LLMInterface):
    """

@@ -58,8 +65,13 @@ class GeminiAdapter(LLMInterface):
        self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)

-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

View file

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
+import logging
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

+logger = get_logger()


class GenericAPIAdapter(LLMInterface):
    """

@@ -58,8 +65,13 @@ class GenericAPIAdapter(LLMInterface):
        self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)

-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

View file

@@ -1,20 +1,23 @@
import litellm
import instructor
from pydantic import BaseModel
-from typing import Type, Optional
+from typing import Type
-from litellm import acompletion, JSONSchemaValidationError
+from litellm import JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
-from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
-from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)

+import logging
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

logger = get_logger()

@@ -47,8 +50,13 @@ class MistralAdapter(LLMInterface):
            api_key=get_llm_config().llm_api_key,
        )

-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -99,31 +107,3 @@ class MistralAdapter(LLMInterface):
            logger.error(f"Schema validation failed: {str(e)}")
            logger.debug(f"Raw response: {e.raw_response}")
            raise ValueError(f"Response failed schema validation: {str(e)}")

-   def show_prompt(self, text_input: str, system_prompt: str) -> str:
-       """
-       Format and display the prompt for a user query.
-
-       Parameters:
-       -----------
-       - text_input (str): Input text from the user to be included in the prompt.
-       - system_prompt (str): The system prompt that will be shown alongside the user input.
-
-       Returns:
-       --------
-       - str: The formatted prompt string combining system prompt and user input.
-       """
-       if not text_input:
-           text_input = "No user input provided."
-       if not system_prompt:
-           raise MissingSystemPromptPathError()
-
-       system_prompt = LLMGateway.read_query_prompt(system_prompt)
-
-       formatted_prompt = (
-           f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
-           if system_prompt
-           else None
-       )
-       return formatted_prompt

View file

@@ -1,4 +1,6 @@
import base64
+import litellm
+import logging
import instructor
from typing import Type
from openai import OpenAI

@@ -7,11 +9,17 @@ from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

+logger = get_logger()


class OllamaAPIAdapter(LLMInterface):

@@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
        )

-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
        return response

-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def create_transcript(self, input_file: str) -> str:
        """
        Generate an audio transcript from a user query.

@@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
        return transcription.text

-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def transcribe_image(self, input_file: str) -> str:
        """
        Transcribe content from an image using base64 encoding.

View file

@@ -7,6 +7,15 @@ from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException

+import logging
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)

@@ -14,19 +23,13 @@ from cognee.infrastructure.llm.exceptions import (
    ContentPolicyFilterError,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    rate_limit_sync,
-    sleep_and_retry_async,
-    sleep_and_retry_sync,
-)
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger

-observe = get_observe()
logger = get_logger()
+observe = get_observe()


class OpenAIAdapter(LLMInterface):
    """

@@ -97,8 +100,13 @@ class OpenAIAdapter(LLMInterface):
        self.fallback_endpoint = fallback_endpoint

    @observe(as_type="generation")
-   @sleep_and_retry_async()
-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -148,10 +156,7 @@ class OpenAIAdapter(LLMInterface):
                InstructorRetryException,
            ) as e:
                if not (self.fallback_model and self.fallback_api_key):
-                   raise ContentPolicyFilterError(
-                       f"The provided input contains content that is not aligned with our content policy: {text_input}"
-                   ) from e
+                   raise e

                try:
                    return await self.aclient.chat.completions.create(
                        model=self.fallback_model,

@@ -186,8 +191,13 @@ class OpenAIAdapter(LLMInterface):
        ) from error

    @observe
-   @sleep_and_retry_sync()
-   @rate_limit_sync
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -231,7 +241,13 @@ class OpenAIAdapter(LLMInterface):
            max_retries=self.MAX_RETRIES,
        )

-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def create_transcript(self, input):
        """
        Generate an audio transcript from a user query.

@@ -263,7 +279,13 @@ class OpenAIAdapter(LLMInterface):
        return transcription

-   @rate_limit_async
+   @retry(
+       stop=stop_after_delay(128),
+       wait=wait_exponential_jitter(2, 128),
+       retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+       before_sleep=before_sleep_log(logger, logging.DEBUG),
+       reraise=True,
+   )
    async def transcribe_image(self, input) -> BaseModel:
        """
        Generate a transcription of an image from a user query.

View file

@@ -105,7 +105,6 @@ class LoaderEngine:
    async def load_file(
        self,
        file_path: str,
-       file_stream: Optional[Any],
        preferred_loaders: Optional[List[str]] = None,
        **kwargs,
    ):

View file

@@ -14,14 +14,6 @@ from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader

logger = get_logger(__name__)

-try:
-    from unstructured.partition.pdf import partition_pdf
-except ImportError as e:
-    logger.info(
-        "unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
-    )
-    raise ImportError from e


@dataclass
class _PageBuffer:

@@ -88,6 +80,8 @@ class AdvancedPdfLoader(LoaderInterface):
            **kwargs,
        }

        # Use partition to extract elements
+       from unstructured.partition.pdf import partition_pdf

        elements = partition_pdf(**partition_kwargs)

        # Process elements into text content

View file

@@ -35,6 +35,7 @@ async def run_pipeline(
    vector_db_config: dict = None,
    graph_db_config: dict = None,
    incremental_loading: bool = False,
+   data_per_batch: int = 20,
):
    validate_pipeline_tasks(tasks)
    await setup_and_check_environment(vector_db_config, graph_db_config)

@@ -50,6 +51,7 @@ async def run_pipeline(
        pipeline_name=pipeline_name,
        context={"dataset": dataset},
        incremental_loading=incremental_loading,
+       data_per_batch=data_per_batch,
    ):
        yield run_info

@@ -62,6 +64,7 @@ async def run_pipeline_per_dataset(
    pipeline_name: str = "custom_pipeline",
    context: dict = None,
    incremental_loading=False,
+   data_per_batch: int = 20,
):
    # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
    await set_database_global_context_variables(dataset.id, dataset.owner_id)

@@ -77,7 +80,7 @@ async def run_pipeline_per_dataset(
        return

    pipeline_run = run_tasks(
-       tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
+       tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
    )

    async for pipeline_run_info in pipeline_run:

View file

@@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
     log_pipeline_run_complete,
     log_pipeline_run_error,
 )
-from .run_tasks_with_telemetry import run_tasks_with_telemetry
 from .run_tasks_data_item import run_tasks_data_item
 from ..tasks.task import Task
@@ -60,6 +59,7 @@ async def run_tasks(
     pipeline_name: str = "unknown_pipeline",
     context: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     if not user:
         user = await get_default_user()
@@ -89,24 +89,29 @@ async def run_tasks(
     if incremental_loading:
         data = await resolve_data_directories(data)
 
-    # Create async tasks per data item that will run the pipeline for the data item
-    data_item_tasks = [
-        asyncio.create_task(
-            run_tasks_data_item(
-                data_item,
-                dataset,
-                tasks,
-                pipeline_name,
-                pipeline_id,
-                pipeline_run_id,
-                context,
-                user,
-                incremental_loading,
-            )
-        )
-        for data_item in data
-    ]
+    # Create and gather batches of async tasks of data items that will run the pipeline for the data item
+    results = []
+    for start in range(0, len(data), data_per_batch):
+        data_batch = data[start : start + data_per_batch]
+
+        data_item_tasks = [
+            asyncio.create_task(
+                run_tasks_data_item(
+                    data_item,
+                    dataset,
+                    tasks,
+                    pipeline_name,
+                    pipeline_id,
+                    pipeline_run_id,
+                    context,
+                    user,
+                    incremental_loading,
+                )
+            )
+            for data_item in data_batch
+        ]
 
-    results = await asyncio.gather(*data_item_tasks)
+        results.extend(await asyncio.gather(*data_item_tasks))
 
     # Remove skipped data items from results
     results = [result for result in results if result]
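For illustration, here is a self-contained sketch of the batching pattern above: items are processed in waves of data_per_batch, so at most that many data items are in flight at once, instead of one task per item all launched together. The process coroutine and the numbers are made up for the example.

import asyncio

async def process(item: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for the per-data-item pipeline work
    return item * 2

async def run_batched(data: list[int], data_per_batch: int = 20) -> list[int]:
    results = []
    for start in range(0, len(data), data_per_batch):
        batch = data[start : start + data_per_batch]
        tasks = [asyncio.create_task(process(item)) for item in batch]
        # Each wave completes fully before the next one starts
        results.extend(await asyncio.gather(*tasks))
    return results

print(asyncio.run(run_batched(list(range(50)), data_per_batch=20)))  # 3 waves: 20 + 20 + 10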


@@ -115,9 +115,8 @@ async def run_tasks_data_item_incremental(
         data_point = (
             await session.execute(select(Data).filter(Data.id == data_id))
         ).scalar_one_or_none()
-        data_point.pipeline_status[pipeline_name] = {
-            str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
-        }
+        status_for_pipeline = data_point.pipeline_status.setdefault(pipeline_name, {})
+        status_for_pipeline[str(dataset.id)] = DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
 
         await session.merge(data_point)
         await session.commit()
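A minimal illustration of why setdefault matters here: assigning a fresh dict replaced the whole per-pipeline mapping, so recording a status for one dataset wiped the statuses of every other dataset under the same pipeline name. The keys and values below are placeholders.

pipeline_status = {"cognify_pipeline": {"dataset-a": "DATA_ITEM_PROCESSING_COMPLETED"}}

# setdefault returns the existing inner dict (or inserts an empty one),
# so updating dataset-b leaves dataset-a's entry intact
statuses = pipeline_status.setdefault("cognify_pipeline", {})
statuses["dataset-b"] = "DATA_ITEM_PROCESSING_COMPLETED"

assert pipeline_status["cognify_pipeline"] == {
    "dataset-a": "DATA_ITEM_PROCESSING_COMPLETED",
    "dataset-b": "DATA_ITEM_PROCESSING_COMPLETED",
}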


@@ -88,6 +88,7 @@ async def run_tasks_distributed(
     pipeline_name: str = "unknown_pipeline",
     context: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     if not user:
         user = await get_default_user()


@@ -1,6 +1,6 @@
-from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.databases.exceptions import EmbeddingException
+import asyncio
+
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint
@@ -33,18 +33,23 @@ async def index_data_points(data_points: list[DataPoint]):
             indexed_data_point.metadata["index_fields"] = [field_name]
             index_points[index_name].append(indexed_data_point)
 
-    for index_name_and_field, indexable_points in index_points.items():
-        first_occurence = index_name_and_field.index("_")
-        index_name = index_name_and_field[:first_occurence]
-        field_name = index_name_and_field[first_occurence + 1 :]
-        try:
-            # In case the amount of indexable points is too large we need to send them in batches
-            batch_size = vector_engine.embedding_engine.get_batch_size()
-            for i in range(0, len(indexable_points), batch_size):
-                batch = indexable_points[i : i + batch_size]
-                await vector_engine.index_data_points(index_name, field_name, batch)
-        except EmbeddingException as e:
-            logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
+    tasks: list[asyncio.Task] = []
+    batch_size = vector_engine.embedding_engine.get_batch_size()
+
+    for index_name_and_field, points in index_points.items():
+        first = index_name_and_field.index("_")
+        index_name = index_name_and_field[:first]
+        field_name = index_name_and_field[first + 1 :]
+
+        # Create embedding requests per batch to run in parallel later
+        for i in range(0, len(points), batch_size):
+            batch = points[i : i + batch_size]
+            tasks.append(
+                asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
+            )
+
+    # Run all embedding requests in parallel
+    await asyncio.gather(*tasks)
 
     return data_points
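The same fan-out pattern in isolation, as a rough sketch: every (index, field) batch becomes its own task up front, and a single gather awaits all embedding requests together rather than batch by batch. The embed_batch coroutine stands in for vector_engine.index_data_points and is made up for the example.

import asyncio

async def embed_batch(index_name: str, field_name: str, batch: list[str]) -> None:
    await asyncio.sleep(0.01)  # placeholder for one embedding/indexing request

async def index_all(index_points: dict[str, list[str]], batch_size: int) -> None:
    tasks = []
    for key, points in index_points.items():
        index_name, field_name = key.split("_", 1)  # mirrors the index("_") split above
        for i in range(0, len(points), batch_size):
            tasks.append(
                asyncio.create_task(embed_batch(index_name, field_name, points[i : i + batch_size]))
            )
    # All batches across all indexes run concurrently
    await asyncio.gather(*tasks)

asyncio.run(index_all({"Entity_name": ["a", "b", "c", "d", "e"]}, batch_size=2))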


@@ -1,3 +1,5 @@
+import asyncio
+
 from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
 from cognee.shared.logging_utils import get_logger
 from collections import Counter
@@ -76,15 +78,20 @@ async def index_graph_edges(
             indexed_data_point.metadata["index_fields"] = [field_name]
             index_points[index_name].append(indexed_data_point)
 
+    # Get maximum batch size for embedding model
+    batch_size = vector_engine.embedding_engine.get_batch_size()
+    tasks: list[asyncio.Task] = []
+
     for index_name, indexable_points in index_points.items():
         index_name, field_name = index_name.split(".")
 
-        # Get maximum batch size for embedding model
-        batch_size = vector_engine.embedding_engine.get_batch_size()
-        # We save the data in batches of {batch_size} to not put a lot of pressure on the database
+        # Create embedding tasks to run in parallel later
         for start in range(0, len(indexable_points), batch_size):
             batch = indexable_points[start : start + batch_size]
-            await vector_engine.index_data_points(index_name, field_name, batch)
+            tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
+
+    # Start all embedding tasks and wait for completion
+    await asyncio.gather(*tasks)
 
     return None
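One detail worth noting about the version above: unlike index_data_points, the batches here are appended as bare coroutines rather than wrapped in asyncio.create_task. asyncio.gather accepts both and wraps coroutines in tasks itself, so the requests still run concurrently; the only practical difference is that the work does not start until gather is awaited.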


@@ -0,0 +1,25 @@
import asyncio
import time

from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter


# This will create the test.db if it doesn't exist
async def main():
    adapter = KuzuAdapter("test.db")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")

    result = await adapter.query("MATCH (n:Node) RETURN COUNT(n)")
    print(f"Reader: Found {result[0][0]} nodes")


if __name__ == "__main__":
    asyncio.run(main())


@@ -0,0 +1,31 @@
import asyncio

import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.api.v1.search import SearchType


async def main():
    await cognee.cognify(datasets=["first_cognify_dataset"])

    query_text = (
        "Tell me what is in the context. Additionally write out 'FIRST_COGNIFY' before your answer"
    )
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=query_text,
        datasets=["first_cognify_dataset"],
    )

    print("Search results:")
    for result_text in search_results:
        print(result_text)


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())


@@ -0,0 +1,31 @@
import asyncio

import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.api.v1.search import SearchType


async def main():
    await cognee.cognify(datasets=["second_cognify_dataset"])

    query_text = (
        "Tell me what is in the context. Additionally write out 'SECOND_COGNIFY' before your answer"
    )
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=query_text,
        datasets=["second_cognify_dataset"],
    )

    print("Search results:")
    for result_text in search_results:
        print(result_text)


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())


@@ -0,0 +1,32 @@
import asyncio
import uuid

from cognee.modules.data.processing.document_types import PdfDocument
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter


def create_node(name):
    document = PdfDocument(
        id=uuid.uuid4(),
        name=name,
        raw_data_location=name,
        external_metadata="test_external_metadata",
        mime_type="test_mime",
    )
    return document


async def main():
    adapter = KuzuAdapter("test.db")
    nodes = [create_node(f"Node{i}") for i in range(5)]

    print("Writer: Starting...")
    await adapter.add_nodes(nodes)
    print("Writer: finished...")

    # Keep the process alive for a while so the reader has to contend
    # with this process for access to the database.
    await asyncio.sleep(10)


if __name__ == "__main__":
    asyncio.run(main())


@@ -1,7 +1,6 @@
 from typing import List
 from cognee.infrastructure.engine import DataPoint
 from cognee.tasks.storage.add_data_points import add_data_points
-from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
 import cognee
 from cognee.infrastructure.databases.graph import get_graph_engine
 import json
@@ -64,7 +63,6 @@ async def create_connected_test_graph():
 async def get_metrics(provider: str, include_optional=True):
-    create_graph_engine.cache_clear()
     cognee.config.set_graph_database_provider(provider)
     graph_engine = await get_graph_engine()
     await graph_engine.delete_graph()


@@ -1,7 +1,12 @@
-from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import assert_metrics
 import asyncio
 
+
+async def main():
+    from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import assert_metrics
+
+    await assert_metrics(provider="neo4j", include_optional=False)
+    await assert_metrics(provider="neo4j", include_optional=True)
+
+
 if __name__ == "__main__":
-    asyncio.run(assert_metrics(provider="neo4j", include_optional=False))
-    asyncio.run(assert_metrics(provider="neo4j", include_optional=True))
+    asyncio.run(main())


@@ -0,0 +1,84 @@
import os
import sys
import asyncio
import cognee
import pathlib
import subprocess

from cognee.shared.logging_utils import get_logger

logger = get_logger()

"""
Test: Redis-based Kùzu Locking Across Subprocesses

This test ensures the Redis shared lock correctly serializes access to the Kùzu
database when multiple subprocesses (writer/reader and cognify tasks) run in parallel.

If this test fails, it indicates the locking mechanism is not properly handling
concurrent subprocess access.
"""


async def concurrent_subprocess_access():
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/concurrent_tasks")
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/concurrent_tasks")
        ).resolve()
    )
    subprocess_directory_path = str(
        pathlib.Path(os.path.join(pathlib.Path(__file__).parent, "subprocesses/")).resolve()
    )
    writer_path = subprocess_directory_path + "/writer.py"
    reader_path = subprocess_directory_path + "/reader.py"

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    writer_process = subprocess.Popen([sys.executable, str(writer_path)])
    reader_process = subprocess.Popen([sys.executable, str(reader_path)])

    # Wait for both processes to complete
    writer_process.wait()
    reader_process.wait()

    logger.info("Basic write/read subprocess example finished")

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    text = """
    This is the text of the first cognify subprocess
    """
    await cognee.add(text, dataset_name="first_cognify_dataset")

    text = """
    This is the text of the second cognify subprocess
    """
    await cognee.add(text, dataset_name="second_cognify_dataset")

    first_cognify_path = subprocess_directory_path + "/simple_cognify_1.py"
    second_cognify_path = subprocess_directory_path + "/simple_cognify_2.py"

    first_cognify_process = subprocess.Popen([sys.executable, str(first_cognify_path)])
    second_cognify_process = subprocess.Popen([sys.executable, str(second_cognify_path)])

    # Wait for both processes to complete
    first_cognify_process.wait()
    second_cognify_process.wait()

    logger.info("Database concurrent subprocess example finished")


if __name__ == "__main__":
    asyncio.run(concurrent_subprocess_access())
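As a rough sketch of the serialization this test exercises (not cognee's actual implementation), a shared Redis lock around database access could look like the following. The lock name and the way the expire/timeout settings map onto redis-py's arguments are assumptions for illustration.

import redis

def run_with_kuzu_lock(do_work, host="localhost", port=6379):
    client = redis.Redis(host=host, port=port)
    # timeout: the lock auto-expires, so a crashed holder cannot block everyone forever
    # blocking_timeout: how long a waiting process keeps retrying before giving up
    lock = client.lock("kuzu-shared-lock", timeout=240, blocking_timeout=300)
    with lock:  # acquired across processes; released even if do_work raises
        return do_work()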


@@ -1,105 +0,0 @@
import os
import pathlib

import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType

logger = get_logger()


async def main():
    cognee.config.set_graph_database_provider("memgraph")
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "cs_explanations"

    explanation_file_path_nlp = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
    )
    await cognee.add([explanation_file_path_nlp], dataset_name)

    explanation_file_path_quantum = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
    )
    await cognee.add([explanation_file_path_quantum], dataset_name)

    await cognee.cognify([dataset_name])

    from cognee.infrastructure.databases.vector import get_vector_engine

    vector_engine = get_vector_engine()
    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
    random_node_name = random_node.payload["text"]

    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text=random_node_name
    )
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.SUMMARIES, query_text=random_node_name
    )
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted results are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.NATURAL_LANGUAGE,
        query_text=f"Find nodes connected to node with name {random_node_name}",
    )
    assert len(search_results) != 0, "Query related natural language results don't exist."
    print("\nExtracted results are:\n")
    for result in search_results:
        print(f"{result}\n")

    user = await get_default_user()
    history = await get_history(user.id)
    assert len(history) == 8, "Search history is not correct."

    await cognee.prune.prune_data()
    data_root_directory = get_storage_config()["data_root_directory"]
    assert not os.path.isdir(data_root_directory), "Local data files are not deleted"

    await cognee.prune.prune_system(metadata=True)

    from cognee.infrastructure.databases.graph import get_graph_engine

    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())


@@ -0,0 +1,87 @@
"""Tests for cache configuration."""

import pytest

from cognee.infrastructure.databases.cache.config import CacheConfig, get_cache_config


def test_cache_config_defaults():
    """Test that CacheConfig has the correct default values."""
    config = CacheConfig()

    assert config.caching is False
    assert config.shared_kuzu_lock is False
    assert config.cache_host == "localhost"
    assert config.cache_port == 6379
    assert config.agentic_lock_expire == 240
    assert config.agentic_lock_timeout == 300


def test_cache_config_custom_values():
    """Test that CacheConfig accepts custom values."""
    config = CacheConfig(
        caching=True,
        shared_kuzu_lock=True,
        cache_host="redis.example.com",
        cache_port=6380,
        agentic_lock_expire=120,
        agentic_lock_timeout=180,
    )

    assert config.caching is True
    assert config.shared_kuzu_lock is True
    assert config.cache_host == "redis.example.com"
    assert config.cache_port == 6380
    assert config.agentic_lock_expire == 120
    assert config.agentic_lock_timeout == 180


def test_cache_config_to_dict():
    """Test the to_dict method returns all configuration values."""
    config = CacheConfig(
        caching=True,
        shared_kuzu_lock=True,
        cache_host="test-host",
        cache_port=7000,
        agentic_lock_expire=100,
        agentic_lock_timeout=200,
    )

    config_dict = config.to_dict()

    assert config_dict == {
        "caching": True,
        "shared_kuzu_lock": True,
        "cache_host": "test-host",
        "cache_port": 7000,
        "agentic_lock_expire": 100,
        "agentic_lock_timeout": 200,
    }


def test_get_cache_config_singleton():
    """Test that get_cache_config returns the same instance."""
    config1 = get_cache_config()
    config2 = get_cache_config()

    assert config1 is config2


def test_cache_config_extra_fields_allowed():
    """Test that CacheConfig allows extra fields due to extra='allow'."""
    config = CacheConfig(extra_field="extra_value", another_field=123)

    assert hasattr(config, "extra_field")
    assert config.extra_field == "extra_value"
    assert hasattr(config, "another_field")
    assert config.another_field == 123


def test_cache_config_boolean_type_validation():
    """Test that boolean fields accept various truthy/falsy values."""
    config1 = CacheConfig(caching="true", shared_kuzu_lock="yes")
    assert config1.caching is True
    assert config1.shared_kuzu_lock is True

    config2 = CacheConfig(caching="false", shared_kuzu_lock="no")
    assert config2.caching is False
    assert config2.shared_kuzu_lock is False
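Since CacheConfig looks like a pydantic settings class, its fields would normally be populated from environment variables as well as keyword arguments. A hypothetical usage sketch follows; the environment variable names are assumptions based on the field names, not confirmed by the source.

import os

# Assumed variable names, matching the field names by convention
os.environ["CACHE_HOST"] = "redis.example.com"
os.environ["SHARED_KUZU_LOCK"] = "true"

from cognee.infrastructure.databases.cache.config import get_cache_config

config = get_cache_config()  # singleton: later calls return this same instance
print(config.to_dict())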


@@ -129,6 +129,30 @@ services:
     networks:
       - cognee-network
 
+  redis:
+    image: redis:7-alpine
+    container_name: redis
+    profiles:
+      - redis
+    ports:
+      - "6379:6379"
+    networks:
+      - cognee-network
+    volumes:
+      - redis_data:/data
+    command: [ "redis-server", "--appendonly", "yes" ]
+
+  redisinsight:
+    image: redislabs/redisinsight:latest
+    container_name: redisinsight
+    restart: always
+    ports:
+      - "5540:5540"
+    networks:
+      - cognee-network
+
 networks:
   cognee-network:
     name: cognee-network
@@ -136,3 +160,4 @@ networks:
 volumes:
   chromadb_data:
   postgres_data:
+  redis_data:
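Because the redis service is guarded by a compose profile, it stays down on a plain docker compose up and only starts when the profile is requested, for example with docker compose --profile redis up -d. Note that redisinsight carries no profiles entry here, so it starts with every docker compose up.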


@@ -83,16 +83,16 @@
     ]
    },
    {
-    "metadata": {},
     "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "import os\n",
      "import pathlib\n",
      "from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
      "from dotenv import load_dotenv"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
     "cell_type": "markdown",
@@ -106,7 +106,9 @@
    },
    {
     "cell_type": "code",
+    "execution_count": null,
     "metadata": {},
+    "outputs": [],
     "source": [
      "# load environment variables from file .env\n",
      "load_dotenv()\n",
@@ -145,9 +147,7 @@
      "    \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
      "  }\n",
      ")"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
     "cell_type": "markdown",
@@ -159,19 +159,19 @@
     ]
    },
    {
-    "metadata": {},
     "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
      "await prune.prune_data()\n",
      "await prune.prune_system(metadata=True)"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
-    "metadata": {},
     "cell_type": "markdown",
+    "metadata": {},
     "source": [
      "## Setup data and cognify\n",
      "\n",
@@ -180,7 +180,9 @@
    },
    {
     "cell_type": "code",
+    "execution_count": null,
     "metadata": {},
+    "outputs": [],
     "source": [
      "# Add sample text to the dataset\n",
      "sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
@@ -205,9 +207,7 @@
      "\n",
      "# Cognify the text data.\n",
      "await cognify([dataset_name])"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
     "cell_type": "markdown",
@@ -215,14 +215,16 @@
     "source": [
      "## Graph Memory visualization\n",
      "\n",
-     "Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
+     "Initialize Neptune as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
      "\n",
      "![visualization](./neptune_analytics_demo.png)"
     ]
    },
    {
     "cell_type": "code",
+    "execution_count": null,
     "metadata": {},
+    "outputs": [],
     "source": [
      "# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
      "# url = await render_graph()\n",
@@ -235,9 +237,7 @@
      "  ).resolve()\n",
      ")\n",
      "await visualize_graph(graph_file_path)"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
     "cell_type": "markdown",
@@ -250,19 +250,19 @@
    },
    {
     "cell_type": "code",
+    "execution_count": null,
     "metadata": {},
+    "outputs": [],
     "source": [
      "# Completion query that uses graph data to form context.\n",
      "graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
      "print(\"\\nGraph completion result is:\")\n",
      "print(graph_completion)"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
-    "metadata": {},
     "cell_type": "markdown",
+    "metadata": {},
     "source": [
      "## SEARCH: RAG Completion\n",
      "\n",
@@ -271,19 +271,19 @@
    },
    {
     "cell_type": "code",
+    "execution_count": null,
     "metadata": {},
+    "outputs": [],
     "source": [
      "# Completion query that uses document chunks to form context.\n",
      "rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
      "print(\"\\nRAG Completion result is:\")\n",
      "print(rag_completion)"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
-    "metadata": {},
     "cell_type": "markdown",
+    "metadata": {},
     "source": [
      "## SEARCH: Graph Insights\n",
      "\n",
@@ -291,8 +291,10 @@
     ]
    },
    {
-    "metadata": {},
     "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "# Search graph insights\n",
      "insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.GRAPH_COMPLETION)\n",
@@ -302,13 +304,11 @@
      "  tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
      "  relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
      "  print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
-    "metadata": {},
     "cell_type": "markdown",
+    "metadata": {},
     "source": [
      "## SEARCH: Entity Summaries\n",
      "\n",
@@ -316,8 +316,10 @@
     ]
    },
    {
-    "metadata": {},
     "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "# Query all summaries related to query.\n",
      "summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
@@ -326,13 +328,11 @@
      "  type = summary[\"type\"]\n",
      "  text = summary[\"text\"]\n",
      "  print(f\"- {type}: {text}\")"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    },
    {
-    "metadata": {},
     "cell_type": "markdown",
+    "metadata": {},
     "source": [
      "## SEARCH: Chunks\n",
      "\n",
@@ -340,8 +340,10 @@
     ]
    },
    {
-    "metadata": {},
     "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
      "print(\"\\nChunk results are:\")\n",
@@ -349,9 +351,7 @@
      "  type = chunk[\"type\"]\n",
      "  text = chunk[\"text\"]\n",
      "  print(f\"- {type}: {text}\")"
-    ],
-    "outputs": [],
-    "execution_count": null
+    ]
    }
   ],
  "metadata": {

poetry.lock (generated): 1319 changed lines, diff suppressed because it is too large.

@@ -7,7 +7,7 @@ authors = [
     { name = "Vasilije Markovic" },
     { name = "Boris Arzentar" },
 ]
-requires-python = ">=3.10,<=3.13"
+requires-python = ">=3.10,<3.14"
 readme = "README.md"
 license = "Apache-2.0"
 classifiers = [
@@ -56,6 +56,7 @@ dependencies = [
     "gunicorn>=20.1.0,<24",
     "websockets>=15.0.1,<16.0.0",
     "mistralai>=1.9.10",
+    "tenacity>=9.0.0",
 ]
 
 [project.optional-dependencies]
@@ -64,14 +65,16 @@ api=[]
 distributed = [
     "modal>=1.0.5,<2.0.0",
 ]
+
 scraping = [
-    "tavily-python>=0.7.0",
+    "tavily-python>=0.7.12",
     "beautifulsoup4>=4.13.1",
     "playwright>=1.9.0",
-    "lxml>=4.9.3,<5.0.0",
+    "lxml>=4.9.3",
     "protego>=0.1",
     "APScheduler>=3.10.0,<=3.11.0"
 ]
+
 neo4j = ["neo4j>=5.28.0,<6"]
 neptune = ["langchain_aws>=0.2.22"]
 postgres = [
@@ -101,7 +104,7 @@ chromadb = [
     "chromadb>=0.6,<0.7",
     "pypika==0.48.9",
 ]
-docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
+docs = ["lxml<6.0.0", "unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
 codegraph = [
     "fastembed<=0.6.0 ; python_version < '3.13'",
     "transformers>=4.46.3,<5",
@@ -140,6 +143,7 @@ dev = [
     "mkdocstrings[python]>=0.26.2,<0.27",
 ]
 debug = ["debugpy>=1.8.9,<2.0.0"]
+redis = ["redis>=5.0.3,<6.0.0"]
 monitoring = ["sentry-sdk[fastapi]>=2.9.0,<3", "langfuse>=2.32.0,<3"]
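With redis declared as an optional extra, the client library stays out of the default install and is pulled in on demand, for example with pip install "cognee[redis]" or the equivalent extras flag of whichever installer is in use.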

uv.lock (generated): 2477 changed lines, diff suppressed because it is too large.
@@ -0,0 +1,31 @@
"""
Run writer and reader in separate subprocesses to test Kuzu locks.
"""

import subprocess
import sys
import time


def main():
    print("=== Kuzu Subprocess Lock Test ===")
    print("Starting writer and reader in separate subprocesses...")
    print("Writer will hold the database lock, reader should block or fail\n")

    start_time = time.time()

    # Start writer and reader subprocesses
    writer_process = subprocess.Popen([sys.executable, "writer.py"])
    reader_process = subprocess.Popen([sys.executable, "reader.py"])

    # Wait for both processes to complete
    writer_process.wait()
    reader_process.wait()

    total_time = time.time() - start_time
    print(f"\nTotal execution time: {total_time:.2f}s")


if __name__ == "__main__":
    main()
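Note that writer.py and reader.py are referenced by bare relative paths, so this helper has to be launched from the directory containing those scripts; the integration test above builds absolute paths to the same files instead.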